diff --git a/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc b/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc
index f1c094d2562f..04cfa8933484 100644
--- a/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc
+++ b/cpp/velox/operators/serializer/VeloxColumnarToRowConverter.cc
@@ -28,57 +28,51 @@ using namespace facebook;
 
 namespace gluten {
 
-void VeloxColumnarToRowConverter::refreshStates(
-    facebook::velox::RowVectorPtr rowVector,
-    int64_t rowId,
-    int64_t memoryThreshold) {
+// Prepares the conversion state for the rows of `rowVector` beginning at `startRow`: chooses
+// numRows_ so that the serialized rows fit into memThreshold_ bytes, enlarges memThreshold_ when
+// one row alone exceeds it, and provides a zero-filled buffer of memThreshold_ bytes.
+void VeloxColumnarToRowConverter::refreshStates(facebook::velox::RowVectorPtr rowVector, int64_t startRow) {
   auto vectorLength = rowVector->size();
   numCols_ = rowVector->childrenSize();
 
   fast_ = std::make_unique(rowVector);
 
-  size_t totalMemorySize = 0;
   if (auto fixedRowSize = velox::row::UnsafeRowFast::fixedRowSize(velox::asRowType(rowVector->type()))) {
-    if (memoryThreshold < fixedRowSize.value()) {
-      memoryThreshold = fixedRowSize.value();
-      LOG(WARNING) << "spark.gluten.sql.columnarToRowMemoryThreshold(" + velox::succinctBytes(memoryThreshold) +
-              ") is too small, it can't hold even one row(" + velox::succinctBytes(fixedRowSize.value()) + ")";
-    }
+    // Grow the threshold when needed so that at least one fixed-size row fits (otherwise numRows_ would be 0).
+    memThreshold_ = std::max(memThreshold_, fixedRowSize.value());
     auto rowSize = fixedRowSize.value();
-    numRows_ = std::min(memoryThreshold / rowSize, vectorLength - rowId);
-    totalMemorySize = rowSize * numRows_;
+    numRows_ = std::min(memThreshold_ / rowSize, vectorLength - startRow);
   } else {
-    int64_t i = rowId;
-    for (; i < vectorLength; ++i) {
-      auto rowSize = fast_->rowSize(i);
-      if (UNLIKELY(totalMemorySize + rowSize > memoryThreshold)) {
-        if (i == rowId) {
-          memoryThreshold = rowSize;
-          LOG(WARNING) << "spark.gluten.sql.columnarToRowMemoryThreshold(" + velox::succinctBytes(memoryThreshold) +
-                  ") is too small, it can't hold even one row(" + velox::succinctBytes(rowSize) + ")";
-        }
+    // Calculate the first row size
+    int64_t totalMemorySize = fast_->rowSize(startRow);
+
+    auto endRow = startRow + 1;
+    for (; endRow < vectorLength; ++endRow) {
+      auto rowSize = fast_->rowSize(endRow);
+      if (UNLIKELY(totalMemorySize + rowSize > memThreshold_)) {
         break;
       } else {
         totalMemorySize += rowSize;
       }
     }
-    numRows_ = i - rowId;
+    // Make sure the threshold is larger than the first row size
+    memThreshold_ = std::max(totalMemorySize, memThreshold_);
+    numRows_ = endRow - startRow;
   }
 
-  if (veloxBuffers_ == nullptr) {
-    veloxBuffers_ = velox::AlignedBuffer::allocate(memoryThreshold, veloxPool_.get());
+  if (nullptr == veloxBuffers_) {
+    veloxBuffers_ = velox::AlignedBuffer::allocate(memThreshold_, veloxPool_.get());
+  } else if (veloxBuffers_->capacity() < memThreshold_) {
+    velox::AlignedBuffer::reallocate(&veloxBuffers_, memThreshold_);
   }
-  if (veloxBuffers_->capacity() < totalMemorySize) {
-    velox::AlignedBuffer::reallocate(&veloxBuffers_, totalMemorySize);
-  }
-
+
   bufferAddress_ = veloxBuffers_->asMutable();
-  memset(bufferAddress_, 0, sizeof(int8_t) * totalMemorySize);
+  memset(bufferAddress_, 0, sizeof(int8_t) * memThreshold_);
 }
 
-void VeloxColumnarToRowConverter::convert(std::shared_ptr cb, int64_t rowId, int64_t memoryThreshold) {
+void VeloxColumnarToRowConverter::convert(std::shared_ptr cb, int64_t startRow) {
   auto veloxBatch = VeloxColumnarBatch::from(veloxPool_.get(), cb);
-  refreshStates(veloxBatch->getRowVector(), rowId, memoryThreshold);
+  refreshStates(veloxBatch->getRowVector(), startRow);
 
   // Initialize the offsets_ , lengths_
   lengths_.clear();
@@ -88,7 +82,7 @@ void VeloxColumnarToRowConverter::convert(std::shared_ptr cb, int
 
   size_t offset = 0;
   for (auto i = 0; i < numRows_; ++i) {
-    auto rowSize = fast_->serialize(rowId + i, (char*)(bufferAddress_ + offset));
+    auto rowSize = fast_->serialize(startRow + i, (char*)(bufferAddress_ + offset));
     lengths_[i] = rowSize;
     if (i > 0) {
       offsets_[i] = offsets_[i - 1] + lengths_[i - 1];