Skip to content

Commit

Permalink
refine code
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma committed Jul 19, 2024
1 parent e81efe0 commit 94ceab6
Show file tree
Hide file tree
Showing 13 changed files with 275 additions and 275 deletions.
36 changes: 28 additions & 8 deletions cpp/core/shuffle/LocalPartitionWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,13 @@ class LocalPartitionWriter::LocalSpiller {
arrow::Status spill(uint32_t partitionId, std::unique_ptr<BlockPayload> payload) {
// Check spill Type.
ARROW_RETURN_IF(
payload->type() != Payload::kUncompressed && payload->type() != Payload::kRaw,
payload->type() == Payload::kToBeCompressed,
arrow::Status::Invalid("Cannot spill payload of type: " + payload->toString()));

if (!opened_) {
opened_ = true;
ARROW_ASSIGN_OR_RAISE(os_, arrow::io::FileOutputStream::Open(spillFile_, true));
ARROW_ASSIGN_OR_RAISE(auto raw, arrow::io::FileOutputStream::Open(spillFile_, true));
ARROW_ASSIGN_OR_RAISE(os_, arrow::io::BufferedOutputStream::Create(16384, pool_, raw));
std::cout << "open spill file: " << spillFile_ << std::endl;
diskSpill_ = std::make_unique<Spill>(Spill::SpillType::kSequentialSpill);
}
Expand All @@ -62,8 +63,10 @@ class LocalPartitionWriter::LocalSpiller {
return arrow::Status::OK();
}

auto payloadType = codec_ != nullptr && payload->numRows() >= compressionThreshold_ ? Payload::kToBeCompressed
: Payload::kUncompressed;
auto payloadType = payload->type();
if (payloadType == Payload::kUncompressed && codec_ != nullptr && payload->numRows() >= compressionThreshold_) {
payloadType = Payload::kToBeCompressed;
}
diskSpill_->insertPayload(
partitionId, payloadType, payload->numRows(), payload->isValidityBuffer(), end - start, pool_, codec_);
return arrow::Status::OK();
Expand Down Expand Up @@ -100,7 +103,7 @@ class LocalPartitionWriter::LocalSpiller {
bool opened_{false};
bool finished_{false};
std::shared_ptr<Spill> diskSpill_{nullptr};
std::shared_ptr<arrow::io::FileOutputStream> os_;
std::shared_ptr<arrow::io::OutputStream> os_;
int64_t spillTime_{0};
};

Expand Down Expand Up @@ -504,10 +507,10 @@ arrow::Status LocalPartitionWriter::stop(ShuffleWriterMetrics* metrics) {
return arrow::Status::OK();
}

arrow::Status LocalPartitionWriter::requestSpill(bool stop) {
arrow::Status LocalPartitionWriter::requestSpill(bool isFinal) {
if (!spiller_ || spiller_->finished()) {
std::string spillFile;
if (stop && useSpillFileAsDataFile()) {
if (isFinal && useSpillFileAsDataFile()) {
spillFile = dataFile_;
} else {
ARROW_ASSIGN_OR_RAISE(spillFile, createTempShuffleFile(nextSpilledFileDir()));
Expand All @@ -534,9 +537,26 @@ arrow::Status LocalPartitionWriter::evict(
std::unique_ptr<InMemoryPayload> inMemoryPayload,
Evict::type evictType,
bool reuseBuffers,
bool hasComplexType) {
bool hasComplexType,
bool isFinal) {
rawPartitionLengths_[partitionId] += inMemoryPayload->getBufferSize();

if (evictType == Evict::kSortSpill) {
if (partitionId < lastEvictPid_) {
RETURN_NOT_OK(finishSpill());
}
lastEvictPid_ = partitionId;

RETURN_NOT_OK(requestSpill(isFinal));

auto payloadType = codec_ ? Payload::Type::kCompressed : Payload::Type::kUncompressed;
ARROW_ASSIGN_OR_RAISE(
auto payload,
inMemoryPayload->toBlockPayload(payloadType, payloadPool_.get(), codec_ ? codec_.get() : nullptr));
RETURN_NOT_OK(spiller_->spill(partitionId, std::move(payload)));
return arrow::Status::OK();
}

if (evictType == Evict::kSpill) {
RETURN_NOT_OK(requestSpill(false));
ARROW_ASSIGN_OR_RAISE(
Expand Down
5 changes: 3 additions & 2 deletions cpp/core/shuffle/LocalPartitionWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ class LocalPartitionWriter : public PartitionWriter {
std::unique_ptr<InMemoryPayload> inMemoryPayload,
Evict::type evictType,
bool reuseBuffers,
bool hasComplexType) override;
bool hasComplexType,
bool isFinal) override;

arrow::Status evict(uint32_t partitionId, std::unique_ptr<BlockPayload> blockPayload, bool stop) override;

Expand Down Expand Up @@ -80,7 +81,7 @@ class LocalPartitionWriter : public PartitionWriter {
private:
void init();

arrow::Status requestSpill(bool stop);
arrow::Status requestSpill(bool isFinal);

arrow::Status finishSpill();

Expand Down
5 changes: 3 additions & 2 deletions cpp/core/shuffle/PartitionWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
namespace gluten {

// Namespace-like wrapper for the eviction mode that callers pass to
// PartitionWriter::evict as `Evict::type`.
// NOTE(review): this text is a diff rendering, so the two `enum type` lines below
// appear to be the removed and the added version of a single declaration; the later
// one additionally lists kSortSpill. Confirm against the repository before relying on it.
struct Evict {
enum type { kCache, kSpill };
enum type { kCache, kSpill, kSortSpill };
};

class PartitionWriter : public Reclaimable {
Expand All @@ -47,7 +47,8 @@ class PartitionWriter : public Reclaimable {
std::unique_ptr<InMemoryPayload> inMemoryPayload,
Evict::type evictType,
bool reuseBuffers,
bool hasComplexType) = 0;
bool hasComplexType,
bool isFinal) = 0;

virtual arrow::Status evict(uint32_t partitionId, std::unique_ptr<BlockPayload> blockPayload, bool stop) = 0;

Expand Down
106 changes: 45 additions & 61 deletions cpp/core/shuffle/Payload.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,14 @@ T* advance(uint8_t** dst) {
return ptr;
}

// Reads the leading payload-type tag from `inputStream`.
// NOTE(review): this hunk is a diff rendering with removed and added lines side by
// side. The `readTypeAndRows` signature, the `numRows` local, the `std::make_pair`
// returns and the second `Read` appear to be the removed version (type + row count);
// the `readType` signature with `return 0;` / `return type;` appears to be the
// replacement, which returns only the type tag.
// Returns 0 as the type when the first read yields zero bytes; a failing read is
// propagated as an error status through the Arrow macros.
// NOTE(review): the first read requests sizeof(Payload::Type) bytes but stores them in
// a uint8_t, which is only safe if Payload::Type is one byte wide — confirm the
// enum's underlying type.
arrow::Result<std::pair<uint8_t, uint32_t>> readTypeAndRows(arrow::io::InputStream* inputStream) {
arrow::Result<uint8_t> readType(arrow::io::InputStream* inputStream) {
uint8_t type;
uint32_t numRows;
ARROW_ASSIGN_OR_RAISE(auto bytes, inputStream->Read(sizeof(Payload::Type), &type));
if (bytes == 0) {
// Zero bytes were read: the end of the stream has been reached.
return std::make_pair(0, 0);
return 0;
}
RETURN_NOT_OK(inputStream->Read(sizeof(uint32_t), &numRows));
return std::make_pair(type, numRows);
return type;
}

arrow::Result<int64_t> compressBuffer(
Expand Down Expand Up @@ -120,13 +118,16 @@ arrow::Status compressAndFlush(
return arrow::Status::OK();
}

// Reads one length-prefixed, uncompressed buffer from `inputStream`.
// The stream is expected to hold an int64 length followed by that many bytes; a
// length equal to kNullBuffer yields a null buffer pointer instead of data.
// NOTE(review): this hunk is a diff rendering with removed and added lines side by
// side. The one-argument signature and the `inputStream->Read(bufferLength)` line
// appear to be the removed version; the two-argument signature, which allocates the
// destination from `pool` and reads into it, appears to be the replacement.
// NOTE(review): both `Read(n, out)` calls are wrapped in RETURN_NOT_OK, so only the
// status is checked and the returned byte count is discarded; a short read of the
// length prefix or of the payload would not be noticed here — confirm callers
// guarantee complete data.
arrow::Result<std::shared_ptr<arrow::Buffer>> readUncompressedBuffer(arrow::io::InputStream* inputStream) {
arrow::Result<std::shared_ptr<arrow::Buffer>> readUncompressedBuffer(
arrow::io::InputStream* inputStream,
arrow::MemoryPool* pool) {
int64_t bufferLength;
RETURN_NOT_OK(inputStream->Read(sizeof(int64_t), &bufferLength));
// Sentinel length: the writer recorded a missing buffer, so there is no payload to read.
if (bufferLength == kNullBuffer) {
return nullptr;
}
ARROW_ASSIGN_OR_RAISE(auto buffer, inputStream->Read(bufferLength));
ARROW_ASSIGN_OR_RAISE(auto buffer, arrow::AllocateResizableBuffer(bufferLength, pool));
RETURN_NOT_OK(inputStream->Read(bufferLength, buffer->mutable_data()));
return buffer;
}

Expand All @@ -148,16 +149,15 @@ arrow::Result<std::shared_ptr<arrow::Buffer>> readCompressedBuffer(
RETURN_NOT_OK(inputStream->Read(sizeof(int64_t), &uncompressedLength));
if (compressedLength == kUncompressedBuffer) {
ARROW_ASSIGN_OR_RAISE(auto uncompressed, arrow::AllocateResizableBuffer(uncompressedLength, pool));
RETURN_NOT_OK(inputStream->Read(uncompressedLength, const_cast<uint8_t*>(uncompressed->data())));
RETURN_NOT_OK(inputStream->Read(uncompressedLength, uncompressed->mutable_data()));
return uncompressed;
}
ARROW_ASSIGN_OR_RAISE(auto compressed, arrow::AllocateBuffer(compressedLength, pool));
RETURN_NOT_OK(inputStream->Read(compressedLength, const_cast<uint8_t*>(compressed->data())));
ARROW_ASSIGN_OR_RAISE(auto compressed, arrow::AllocateResizableBuffer(compressedLength, pool));
RETURN_NOT_OK(inputStream->Read(compressedLength, compressed->mutable_data()));

ScopedTimer timer(&decompressTime);
ARROW_ASSIGN_OR_RAISE(auto output, arrow::AllocateResizableBuffer(uncompressedLength, pool));
RETURN_NOT_OK(codec->Decompress(
compressedLength, compressed->data(), uncompressedLength, const_cast<uint8_t*>(output->data())));
RETURN_NOT_OK(codec->Decompress(compressedLength, compressed->data(), uncompressedLength, output->mutable_data()));
return output;
}

Expand Down Expand Up @@ -238,6 +238,8 @@ arrow::Status BlockPayload::serialize(arrow::io::OutputStream* outputStream) {
ScopedTimer timer(&writeTime_);
RETURN_NOT_OK(outputStream->Write(&kUncompressedType, sizeof(Type)));
RETURN_NOT_OK(outputStream->Write(&numRows_, sizeof(uint32_t)));
uint32_t numBuffers = buffers_.size();
RETURN_NOT_OK(outputStream->Write(&numBuffers, sizeof(uint32_t)));
for (auto& buffer : buffers_) {
if (!buffer) {
RETURN_NOT_OK(outputStream->Write(&kNullBuffer, sizeof(int64_t)));
Expand All @@ -255,6 +257,8 @@ arrow::Status BlockPayload::serialize(arrow::io::OutputStream* outputStream) {
ScopedTimer timer(&writeTime_);
RETURN_NOT_OK(outputStream->Write(&kCompressedType, sizeof(Type)));
RETURN_NOT_OK(outputStream->Write(&numRows_, sizeof(uint32_t)));
uint32_t numBuffers = buffers_.size();
RETURN_NOT_OK(outputStream->Write(&numBuffers, sizeof(uint32_t)));
}
for (auto& buffer : buffers_) {
RETURN_NOT_OK(compressAndFlush(std::move(buffer), outputStream, codec_, pool_, compressTime_, writeTime_));
Expand All @@ -264,6 +268,8 @@ arrow::Status BlockPayload::serialize(arrow::io::OutputStream* outputStream) {
ScopedTimer timer(&writeTime_);
RETURN_NOT_OK(outputStream->Write(&kCompressedType, sizeof(Type)));
RETURN_NOT_OK(outputStream->Write(&numRows_, sizeof(uint32_t)));
uint32_t buffers = numBuffers();
RETURN_NOT_OK(outputStream->Write(&buffers, sizeof(uint32_t)));
RETURN_NOT_OK(outputStream->Write(std::move(buffers_[0])));
} break;
case Type::kRaw: {
Expand Down Expand Up @@ -293,58 +299,26 @@ arrow::Result<std::vector<std::shared_ptr<arrow::Buffer>>> BlockPayload::deseria
uint32_t& numRows,
int64_t& decompressTime) {
static const std::vector<std::shared_ptr<arrow::Buffer>> kEmptyBuffers{};
ARROW_ASSIGN_OR_RAISE(auto typeAndRows, readTypeAndRows(inputStream));
if (typeAndRows.first == 0) {
ARROW_ASSIGN_OR_RAISE(auto type, readType(inputStream));
if (type == 0) {
numRows = 0;
return kEmptyBuffers;
}
numRows = typeAndRows.second;
auto fields = schema->fields();
RETURN_NOT_OK(inputStream->Read(sizeof(uint32_t), &numRows));
uint32_t numBuffers;
RETURN_NOT_OK(inputStream->Read(sizeof(uint32_t), &numBuffers));

auto isCompressionEnabled = typeAndRows.first == Type::kCompressed;
auto readBuffer = [&]() {
bool isCompressionEnabled = type == Type::kCompressed;
std::vector<std::shared_ptr<arrow::Buffer>> buffers;
buffers.reserve(numBuffers);
for (auto i = 0; i < numBuffers; ++i) {
buffers.emplace_back();
if (isCompressionEnabled) {
return readCompressedBuffer(inputStream, codec, pool, decompressTime);
ARROW_ASSIGN_OR_RAISE(buffers.back(), readCompressedBuffer(inputStream, codec, pool, decompressTime));
} else {
return readUncompressedBuffer(inputStream);
}
};

bool hasComplexDataType = false;
std::vector<std::shared_ptr<arrow::Buffer>> buffers;
for (const auto& field : fields) {
auto fieldType = field->type()->id();
switch (fieldType) {
case arrow::BinaryType::type_id:
case arrow::StringType::type_id: {
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
break;
}
case arrow::StructType::type_id:
case arrow::MapType::type_id:
case arrow::ListType::type_id: {
hasComplexDataType = true;
} break;
case arrow::NullType::type_id:
break;
default: {
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
break;
}
ARROW_ASSIGN_OR_RAISE(buffers.back(), readUncompressedBuffer(inputStream, pool));
}
}
if (hasComplexDataType) {
buffers.emplace_back();
ARROW_ASSIGN_OR_RAISE(buffers.back(), readBuffer());
}
return buffers;
}

Expand Down Expand Up @@ -500,13 +474,23 @@ arrow::Status UncompressedDiskBlockPayload::serialize(arrow::io::OutputStream* o
arrow::Status::Invalid(
"Invalid payload type: " + std::to_string(type_) +
", should be either Payload::kUncompressed or Payload::kToBeCompressed"));
ARROW_ASSIGN_OR_RAISE(auto startPos, inputStream_->Tell());
auto typeAndRows = readTypeAndRows(inputStream_);
// Discard type and rows.
RETURN_NOT_OK(typeAndRows.status());
RETURN_NOT_OK(outputStream->Write(&kCompressedType, sizeof(kCompressedType)));
RETURN_NOT_OK(outputStream->Write(&numRows_, sizeof(uint32_t)));
auto readPos = startPos + sizeof(kUncompressedType) + sizeof(uint32_t);

ARROW_ASSIGN_OR_RAISE(auto startPos, inputStream_->Tell());

// Discard original type and rows.
Payload::Type type;
uint32_t numRows;
ARROW_ASSIGN_OR_RAISE(auto bytes, inputStream_->Read(sizeof(Payload::Type), &type));
ARROW_ASSIGN_OR_RAISE(bytes, inputStream_->Read(sizeof(uint32_t), &numRows));
uint32_t numBuffers = 0;
ARROW_ASSIGN_OR_RAISE(bytes, inputStream_->Read(sizeof(uint32_t), &numBuffers));
ARROW_RETURN_IF(bytes == 0 || numBuffers == 0, arrow::Status::Invalid("Cannot serialize payload with 0 buffers."));
RETURN_NOT_OK(outputStream->Write(&numBuffers, sizeof(uint32_t)));

// Advance Payload::Type, rows and numBuffers.
auto readPos = startPos + sizeof(Payload::Type) + sizeof(uint32_t) + sizeof(uint32_t);
while (readPos - startPos < rawSize_) {
ARROW_ASSIGN_OR_RAISE(auto uncompressed, readUncompressedBuffer());
ARROW_ASSIGN_OR_RAISE(readPos, inputStream_->Tell());
Expand Down
2 changes: 1 addition & 1 deletion cpp/core/shuffle/Payload.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class Payload {
}

// Returns the number of buffers, taken from the size of the vector that
// `isValidityBuffer_` points to.
// NOTE(review): this hunk is a diff rendering; the two `return` lines appear to be
// the removed and the added version of the same statement. The first dereferences
// `isValidityBuffer_` unconditionally; the second reports 1 when the pointer is null.
uint32_t numBuffers() {
return isValidityBuffer_->size();
return isValidityBuffer_ ? isValidityBuffer_->size() : 1;
}

const std::vector<bool>* isValidityBuffer() const {
Expand Down
7 changes: 3 additions & 4 deletions cpp/core/shuffle/rss/RssPartitionWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,10 @@ arrow::Status RssPartitionWriter::evict(
std::unique_ptr<InMemoryPayload> inMemoryPayload,
Evict::type evictType,
bool reuseBuffers,
bool hasComplexType) {
bool hasComplexType,
bool isFinal) {
rawPartitionLengths_[partitionId] += inMemoryPayload->getBufferSize();
auto payloadType = (codec_ && inMemoryPayload->numRows() >= options_.compressionThreshold)
? Payload::Type::kCompressed
: Payload::Type::kUncompressed;
auto payloadType = codec_ ? Payload::Type::kCompressed : Payload::Type::kUncompressed;
ARROW_ASSIGN_OR_RAISE(
auto payload, inMemoryPayload->toBlockPayload(payloadType, payloadPool_.get(), codec_ ? codec_.get() : nullptr));
// Copy payload to arrow buffered os.
Expand Down
3 changes: 2 additions & 1 deletion cpp/core/shuffle/rss/RssPartitionWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class RssPartitionWriter final : public PartitionWriter {
std::unique_ptr<InMemoryPayload> inMemoryPayload,
Evict::type evictType,
bool reuseBuffers,
bool hasComplexType) override;
bool hasComplexType,
bool isFinal) override;

arrow::Status evict(uint32_t partitionId, std::unique_ptr<BlockPayload> blockPayload, bool stop) override;

Expand Down
8 changes: 6 additions & 2 deletions cpp/velox/shuffle/VeloxHashBasedShuffleWriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,10 @@ arrow::Status VeloxHashBasedShuffleWriter::initColumnTypes(const facebook::velox
}
}

if (hasComplexType_) {
isValidityBuffer_.push_back(false);
}

fixedWidthColumnCount_ = simpleColumnIndices_.size();

simpleColumnIndices_.insert(simpleColumnIndices_.end(), binaryColumnIndices_.begin(), binaryColumnIndices_.end());
Expand Down Expand Up @@ -949,7 +953,7 @@ arrow::Status VeloxHashBasedShuffleWriter::evictBuffers(
if (!buffers.empty()) {
auto payload = std::make_unique<InMemoryPayload>(numRows, &isValidityBuffer_, std::move(buffers));
RETURN_NOT_OK(
partitionWriter_->evict(partitionId, std::move(payload), Evict::kCache, reuseBuffers, hasComplexType_));
partitionWriter_->evict(partitionId, std::move(payload), Evict::kCache, reuseBuffers, hasComplexType_, false));
}
return arrow::Status::OK();
}
Expand Down Expand Up @@ -1360,7 +1364,7 @@ arrow::Result<int64_t> VeloxHashBasedShuffleWriter::evictPartitionBuffersMinSize
auto pid = item.first;
ARROW_ASSIGN_OR_RAISE(auto buffers, assembleBuffers(pid, false));
auto payload = std::make_unique<InMemoryPayload>(item.second, &isValidityBuffer_, std::move(buffers));
RETURN_NOT_OK(partitionWriter_->evict(pid, std::move(payload), Evict::kSpill, false, hasComplexType_));
RETURN_NOT_OK(partitionWriter_->evict(pid, std::move(payload), Evict::kSpill, false, hasComplexType_, false));
evicted = beforeEvict - partitionBufferPool_->bytes_allocated();
if (evicted >= size) {
break;
Expand Down
Loading

0 comments on commit 94ceab6

Please sign in to comment.