-
Notifications
You must be signed in to change notification settings - Fork 447
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
27 changed files
with
674 additions
and
272 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "shuffle/BlockPayload.h" | ||
|
||
namespace gluten {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,304 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#include <cstdint>
#include <cstring>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include <arrow/buffer.h>
#include "shuffle/Options.h"
#include "shuffle/PartitionWriter.h"
#include "shuffle/Utils.h"
|
||
namespace gluten { | ||
// A block (see BlockPayload below) represents data to be cached in memory or spilled. | ||
// Its buffers can be stored either compressed or uncompressed. | ||
|
||
namespace { | ||
|
||
// Sentinel values stored in the int64 length fields of the block layout.

// Written as the uncompressed length that accompanies a null buffer.
constexpr int64_t kZeroBufferLength{0};
// Written in place of a length to mark a null (absent) buffer.
constexpr int64_t kNullBuffer{-1};
// Written as the compressed length when the payload was stored without compression.
constexpr int64_t kUncompressedBuffer{-2};
|
||
// Copies the object representation of `data` to `*dst` and advances `*dst` by sizeof(T).
//
// The destination is an arbitrary position inside a byte buffer and is therefore not
// guaranteed to be aligned for T, so the value is stored with memcpy instead of through
// a reinterpret_cast'ed T* (which would be a misaligned, aliasing-violating store).
//
// @param dst  In/out cursor into the destination byte buffer; must reference at least
//             sizeof(T) writable bytes.
// @param data Value to store; T is expected to be trivially copyable.
template <typename T>
void write(uint8_t** dst, T data) {
  std::memcpy(*dst, &data, sizeof(T));
  *dst += sizeof(T);
}
|
||
// Reserves sizeof(T) bytes at the cursor: returns the current position of `*dst` viewed
// as a T* and moves `*dst` forward by sizeof(T). Nothing is written to the reserved bytes.
template <typename T>
T* advance(uint8_t** dst) {
  uint8_t* const reserved = *dst;
  *dst = reserved + sizeof(T);
  return reinterpret_cast<T*>(reserved);
}
|
||
// Appends one buffer to `output` using the layout
//   | compressedLength (int64) | uncompressedLength (int64) | payload |
// and advances `output` past everything that was written.
//
// - A null `buffer` is encoded as (kNullBuffer, kZeroBufferLength) with no payload.
// - If the codec output is larger than the input, the raw bytes are stored instead and
//   compressedLength is set to kUncompressedBuffer.
//
// Header fields are stored with memcpy because `output` can sit at any byte offset
// (it follows a payload of arbitrary length), so it is not guaranteed to be aligned
// for int64_t.
//
// @param buffer       Buffer to encode; reset to nullptr after a non-null buffer has been written.
// @param output       In/out write cursor into the destination memory.
// @param outputLength Passed to the codec as the capacity of the region that follows the
//                     two header fields.
// @param options      Supplies the codec (`options->codec`); it is dereferenced unconditionally
//                     for non-null buffers.
// @return The codec's error status if compression fails, OK otherwise.
arrow::Status compressBuffer(
    std::shared_ptr<arrow::Buffer>& buffer,
    uint8_t*& output,
    int64_t outputLength,
    ShuffleWriterOptions* options) {
  const auto putInt64 = [](uint8_t* dst, int64_t value) { std::memcpy(dst, &value, sizeof(value)); };

  if (!buffer) {
    putInt64(output, kNullBuffer);
    output += sizeof(int64_t);
    putInt64(output, kZeroBufferLength);
    output += sizeof(int64_t);
    return arrow::Status::OK();
  }

  // Reserve the compressedLength slot; it is filled in once the outcome is known.
  uint8_t* const compressedLengthSlot = output;
  output += sizeof(int64_t);

  const int64_t uncompressedLength = buffer->size();
  putInt64(output, uncompressedLength);
  output += sizeof(int64_t);

  ARROW_ASSIGN_OR_RAISE(
      auto compressedLength, options->codec->Compress(uncompressedLength, buffer->data(), outputLength, output));

  int64_t storedLength = 0;
  if (compressedLength > uncompressedLength) {
    // Compression did not help: overwrite the codec output with the raw bytes.
    std::memcpy(output, buffer->data(), uncompressedLength);
    output += uncompressedLength;
    storedLength = kUncompressedBuffer;
  } else {
    output += compressedLength;
    storedLength = static_cast<int64_t>(compressedLength);
  }
  putInt64(compressedLengthSlot, storedLength);

  // Release buffer after compression.
  buffer = nullptr;
  return arrow::Status::OK();
}
|
||
} // namespace | ||
|
||
class BlockPayload : public Payload { | ||
public: | ||
enum Type : int32_t { kCompressed, kUncompressed }; | ||
|
||
BlockPayload(BlockPayload::Type type, uint32_t numRows, std::vector<std::shared_ptr<arrow::Buffer>> buffers) | ||
: type_(type), numRows_(numRows), buffers_(std::move(buffers)) {} | ||
|
||
static arrow::Result<std::unique_ptr<BlockPayload>> fromBuffers( | ||
uint32_t numRows, | ||
std::vector<std::shared_ptr<arrow::Buffer>> buffers, | ||
ShuffleWriterOptions* options, | ||
bool reuseBuffers, | ||
bool shouldCompressBuffers) { | ||
if (options->codec && numRows >= options->compression_threshold && shouldCompressBuffers) { | ||
// Compress. | ||
// Compressed buffer layout: | buffer1 compressedLength | buffer1 uncompressedLength | buffer1 | ... | ||
auto metadataLength = sizeof(int64_t) * 2 * buffers.size(); | ||
int64_t totalCompressedLength = | ||
std::accumulate(buffers.begin(), buffers.end(), 0LL, [&](auto sum, const auto& buffer) { | ||
if (!buffer) { | ||
return sum; | ||
} | ||
return sum + options->codec->MaxCompressedLen(buffer->size(), buffer->data()); | ||
}); | ||
ARROW_ASSIGN_OR_RAISE( | ||
std::shared_ptr<arrow::ResizableBuffer> compressed, | ||
arrow::AllocateResizableBuffer( | ||
metadataLength + totalCompressedLength, options->ipc_write_options.memory_pool)); | ||
auto output = compressed->mutable_data(); | ||
|
||
// Compress buffers one by one. | ||
for (auto& buffer : buffers) { | ||
auto availableLength = compressed->size() - (output - compressed->data()); | ||
RETURN_NOT_OK(compressBuffer(buffer, output, availableLength, options)); | ||
} | ||
|
||
int64_t actualLength = output - compressed->data(); | ||
ARROW_RETURN_IF(actualLength < 0, arrow::Status::Invalid("Writing compressed buffer out of bound.")); | ||
RETURN_NOT_OK(compressed->Resize(actualLength)); | ||
return std::make_unique<BlockPayload>( | ||
Type::kCompressed, numRows, std::vector<std::shared_ptr<arrow::Buffer>>{compressed}); | ||
} | ||
if (reuseBuffers) { | ||
// Copy. | ||
std::vector<std::shared_ptr<arrow::Buffer>> copies; | ||
for (auto& buffer : buffers) { | ||
if (!buffer) { | ||
copies.push_back(nullptr); | ||
continue; | ||
} | ||
ARROW_ASSIGN_OR_RAISE( | ||
auto copy, arrow::AllocateResizableBuffer(buffer->size(), options->ipc_write_options.memory_pool)); | ||
memcpy(copy->mutable_data(), buffer->data(), buffer->size()); | ||
copies.push_back(std::move(copy)); | ||
} | ||
return std::make_unique<BlockPayload>(Type::kUncompressed, numRows, std::move(copies)); | ||
} | ||
return std::make_unique<BlockPayload>(Type::kUncompressed, numRows, std::move(buffers)); | ||
} | ||
|
||
arrow::Status serialize(arrow::io::OutputStream* outputStream) override { | ||
RETURN_NOT_OK(outputStream->Write(&type_, sizeof(Type))); | ||
RETURN_NOT_OK(outputStream->Write(&numRows_, sizeof(uint32_t))); | ||
if (type_ == Type::kUncompressed) { | ||
for (auto& buffer : buffers_) { | ||
if (!buffer) { | ||
RETURN_NOT_OK(outputStream->Write(&kNullBuffer, sizeof(int64_t))); | ||
continue; | ||
} | ||
int64_t bufferSize = buffer->size(); | ||
RETURN_NOT_OK(outputStream->Write(&bufferSize, sizeof(int64_t))); | ||
RETURN_NOT_OK(outputStream->Write(std::move(buffer))); | ||
} | ||
} else { | ||
RETURN_NOT_OK(outputStream->Write(std::move(buffers_[0]))); | ||
} | ||
buffers_.clear(); | ||
return arrow::Status::OK(); | ||
} | ||
|
||
static arrow::Result<std::vector<std::shared_ptr<arrow::Buffer>>> deserialize( | ||
arrow::io::InputStream* inputStream, | ||
const std::shared_ptr<arrow::Schema>& schema, | ||
const std::shared_ptr<arrow::util::Codec>& codec, | ||
arrow::MemoryPool* pool, | ||
uint32_t& numRows) { | ||
static const std::vector<std::shared_ptr<arrow::Buffer>> kEmptyBuffers{}; | ||
ARROW_ASSIGN_OR_RAISE(auto typeAndRows, readTypeAndRows(inputStream)); | ||
if (typeAndRows.first == kIpcContinuationToken && typeAndRows.second == kZeroLength) { | ||
numRows = 0; | ||
return kEmptyBuffers; | ||
} | ||
numRows = typeAndRows.second; | ||
auto fields = schema->fields(); | ||
|
||
auto isCompressionEnabled = typeAndRows.first == Type::kUncompressed || codec == nullptr; | ||
auto readBuffer = [&]() { | ||
if (isCompressionEnabled) { | ||
return readUncompressedBuffer(inputStream); | ||
} else { | ||
return readCompressedBuffer(inputStream, codec, pool); | ||
} | ||
}; | ||
|
||
bool hasComplexDataType = false; | ||
std::vector<std::shared_ptr<arrow::Buffer>> buffers; | ||
for (const auto& field : fields) { | ||
auto fieldType = field->type()->id(); | ||
switch (fieldType) { | ||
case arrow::BinaryType::type_id: | ||
case arrow::StringType::type_id: { | ||
buffers.emplace_back(); | ||
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer()); | ||
buffers.emplace_back(); | ||
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer()); | ||
buffers.emplace_back(); | ||
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer()); | ||
break; | ||
} | ||
case arrow::StructType::type_id: | ||
case arrow::MapType::type_id: | ||
case arrow::ListType::type_id: { | ||
hasComplexDataType = true; | ||
} break; | ||
default: { | ||
buffers.emplace_back(); | ||
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer()); | ||
buffers.emplace_back(); | ||
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer()); | ||
break; | ||
} | ||
} | ||
} | ||
if (hasComplexDataType) { | ||
buffers.emplace_back(); | ||
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer()); | ||
} | ||
return buffers; | ||
} | ||
|
||
static arrow::Result<std::pair<int32_t, uint32_t>> readTypeAndRows(arrow::io::InputStream* inputStream) { | ||
int32_t type; | ||
uint32_t numRows; | ||
RETURN_NOT_OK(inputStream->Read(sizeof(Type), &type)); | ||
RETURN_NOT_OK(inputStream->Read(sizeof(uint32_t), &numRows)); | ||
return std::make_pair(type, numRows); | ||
} | ||
|
||
static arrow::Status mergeCompressed( | ||
arrow::io::InputStream* inputStream, | ||
arrow::io::OutputStream* outputStream, | ||
uint32_t numRows, | ||
int64_t totalLength) { | ||
static const Type kType = Type::kUncompressed; | ||
RETURN_NOT_OK(outputStream->Write(&kType, sizeof(Type))); | ||
RETURN_NOT_OK(outputStream->Write(&numRows, sizeof(uint32_t))); | ||
RETURN_NOT_OK(outputStream->Write(&totalLength, sizeof(int64_t))); | ||
ARROW_ASSIGN_OR_RAISE(auto buffer, inputStream->Read(totalLength)); | ||
RETURN_NOT_OK(outputStream->Write(buffer)); | ||
return arrow::Status::OK(); | ||
} | ||
|
||
static arrow::Result<std::shared_ptr<arrow::Buffer>> readUncompressedBuffer(arrow::io::InputStream* inputStream) { | ||
int64_t bufferLength; | ||
RETURN_NOT_OK(inputStream->Read(sizeof(int64_t), &bufferLength)); | ||
if (bufferLength == kNullBuffer) { | ||
return nullptr; | ||
} | ||
ARROW_ASSIGN_OR_RAISE(auto buffer, inputStream->Read(bufferLength)); | ||
return buffer; | ||
} | ||
|
||
static arrow::Result<std::shared_ptr<arrow::Buffer>> readCompressedBuffer( | ||
arrow::io::InputStream* inputStream, | ||
const std::shared_ptr<arrow::util::Codec>& codec, | ||
arrow::MemoryPool* pool) { | ||
int64_t compressedLength; | ||
int64_t uncompressedLength; | ||
RETURN_NOT_OK(inputStream->Read(sizeof(int64_t), &compressedLength)); | ||
RETURN_NOT_OK(inputStream->Read(sizeof(int64_t), &uncompressedLength)); | ||
if (compressedLength == kNullBuffer) { | ||
return nullptr; | ||
} | ||
if (compressedLength == kUncompressedBuffer) { | ||
ARROW_ASSIGN_OR_RAISE(auto uncompressed, arrow::AllocateBuffer(uncompressedLength, pool)); | ||
RETURN_NOT_OK(inputStream->Read(uncompressedLength, const_cast<uint8_t*>(uncompressed->data()))); | ||
return uncompressed; | ||
} | ||
ARROW_ASSIGN_OR_RAISE(auto compressed, arrow::AllocateBuffer(compressedLength, pool)); | ||
RETURN_NOT_OK(inputStream->Read(compressedLength, const_cast<uint8_t*>(compressed->data()))); | ||
ARROW_ASSIGN_OR_RAISE(auto output, arrow::AllocateBuffer(uncompressedLength, pool)); | ||
RETURN_NOT_OK(codec->Decompress( | ||
compressedLength, compressed->data(), uncompressedLength, const_cast<uint8_t*>(output->data()))); | ||
return output; | ||
} | ||
|
||
static arrow::Status mergeUncompressed(arrow::io::InputStream* inputStream, arrow::ResizableBuffer* output) { | ||
ARROW_ASSIGN_OR_RAISE(auto input, readUncompressedBuffer(inputStream)); | ||
auto data = output->mutable_data() + output->size(); | ||
auto newSize = output->size() + input->size(); | ||
RETURN_NOT_OK(output->Resize(newSize)); | ||
memcpy(data, input->data(), input->size()); | ||
return arrow::Status::OK(); | ||
} | ||
|
||
static arrow::Status compressAndWrite( | ||
std::shared_ptr<arrow::Buffer> buffer, | ||
arrow::io::OutputStream* outputStream, | ||
ShuffleWriterOptions* options) { | ||
auto maxCompressedLength = options->codec->MaxCompressedLen(buffer->size(), buffer->data()); | ||
ARROW_ASSIGN_OR_RAISE( | ||
std::shared_ptr<arrow::ResizableBuffer> compressed, | ||
arrow::AllocateResizableBuffer( | ||
sizeof(int64_t) * 2 + maxCompressedLength, options->ipc_write_options.memory_pool)); | ||
auto output = compressed->mutable_data(); | ||
RETURN_NOT_OK(compressBuffer(buffer, output, maxCompressedLength, options)); | ||
RETURN_NOT_OK(outputStream->Write(compressed->data(), output - compressed->data())); | ||
return arrow::Status::OK(); | ||
} | ||
|
||
private: | ||
Type type_; | ||
uint32_t numRows_; | ||
std::vector<std::shared_ptr<arrow::Buffer>> buffers_; | ||
}; | ||
|
||
} // namespace gluten |
Oops, something went wrong.