From f72349ed8b18b40b45428a2c11bb658988c8e97c Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Thu, 6 Jun 2024 15:32:39 +0800 Subject: [PATCH] [VL] Make ColumnarBatch::getRowBytes leak-safe (#6002) --- cpp/core/jni/JniWrapper.cc | 29 ++++++++++---------------- cpp/core/memory/ColumnarBatch.cc | 16 +++++++------- cpp/core/memory/ColumnarBatch.h | 9 ++++---- cpp/velox/memory/VeloxColumnarBatch.cc | 10 ++++----- cpp/velox/memory/VeloxColumnarBatch.h | 2 +- 5 files changed, 30 insertions(+), 36 deletions(-) diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc index f5a6c4bd70d0..db498f43adbf 100644 --- a/cpp/core/jni/JniWrapper.cc +++ b/cpp/core/jni/JniWrapper.cc @@ -72,8 +72,8 @@ static jclass shuffleReaderMetricsClass; static jmethodID shuffleReaderMetricsSetDecompressTime; static jmethodID shuffleReaderMetricsSetDeserializeTime; -static jclass block_stripes_class; -static jmethodID block_stripes_constructor; +static jclass blockStripesClass; +static jmethodID blockStripesConstructor; class JavaInputStreamAdaptor final : public arrow::io::InputStream { public: @@ -280,9 +280,9 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) { shuffleReaderMetricsSetDeserializeTime = getMethodIdOrError(env, shuffleReaderMetricsClass, "setDeserializeTime", "(J)V"); - block_stripes_class = + blockStripesClass = createGlobalClassReferenceOrError(env, "Lorg/apache/spark/sql/execution/datasources/BlockStripes;"); - block_stripes_constructor = env->GetMethodID(block_stripes_class, "", "(J[J[II[B)V"); + blockStripesConstructor = env->GetMethodID(blockStripesClass, "", "(J[J[II[B)V"); return jniVersion; } @@ -297,7 +297,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) { env->DeleteGlobalRef(nativeColumnarToRowInfoClass); env->DeleteGlobalRef(byteArrayClass); env->DeleteGlobalRef(shuffleReaderMetricsClass); - env->DeleteGlobalRef(block_stripes_class); + env->DeleteGlobalRef(blockStripesClass); gluten::getJniErrorState()->close(); gluten::getJniCommonState()->close(); @@ -1224,14 +1224,13 @@ Java_org_apache_gluten_datasource_DatasourceJniWrapper_splitBlockByPartitionAndB } MemoryManager* memoryManager = reinterpret_cast(memoryManagerId); - auto result = batch->getRowBytes(0); - auto rowBytes = result.first; + auto result = batch->toUnsafeRow(0); + auto rowBytes = result.data(); auto newBatchHandle = ctx->objectStore()->save(ctx->select(memoryManager, batch, partitionColIndiceVec)); - auto bytesSize = result.second; + auto bytesSize = result.size(); jbyteArray bytesArray = env->NewByteArray(bytesSize); env->SetByteArrayRegion(bytesArray, 0, bytesSize, reinterpret_cast(rowBytes)); - delete[] rowBytes; jlongArray batchArray = env->NewLongArray(1); long* cBatchArray = new long[1]; @@ -1239,15 +1238,9 @@ Java_org_apache_gluten_datasource_DatasourceJniWrapper_splitBlockByPartitionAndB env->SetLongArrayRegion(batchArray, 0, 1, cBatchArray); delete[] cBatchArray; - jobject block_stripes = env->NewObject( - block_stripes_class, - block_stripes_constructor, - batchHandle, - batchArray, - nullptr, - batch->numColumns(), - bytesArray); - return block_stripes; + jobject blockStripes = env->NewObject( + blockStripesClass, blockStripesConstructor, batchHandle, batchArray, nullptr, batch->numColumns(), bytesArray); + return blockStripes; JNI_METHOD_END(nullptr) } diff --git a/cpp/core/memory/ColumnarBatch.cc b/cpp/core/memory/ColumnarBatch.cc index bb80510ee351..23567535d50a 100644 --- a/cpp/core/memory/ColumnarBatch.cc +++ b/cpp/core/memory/ColumnarBatch.cc @@ -43,8 +43,8 @@ int64_t ColumnarBatch::getExportNanos() const { return exportNanos_; } -std::pair ColumnarBatch::getRowBytes(int32_t rowId) const { - throw gluten::GlutenException("Not implemented getRowBytes for ColumnarBatch"); +std::vector ColumnarBatch::toUnsafeRow(int32_t rowId) const { + throw gluten::GlutenException("Not implemented toUnsafeRow for ColumnarBatch"); } std::ostream& operator<<(std::ostream& os, const ColumnarBatch& columnarBatch) { @@ -86,8 +86,8 @@ std::shared_ptr ArrowColumnarBatch::exportArrowArray() { return cArray; } -std::pair ArrowColumnarBatch::getRowBytes(int32_t rowId) const { - throw gluten::GlutenException("Not implemented getRowBytes for ArrowColumnarBatch"); +std::vector ArrowColumnarBatch::toUnsafeRow(int32_t rowId) const { + throw gluten::GlutenException("#toUnsafeRow of ArrowColumnarBatch is not implemented"); } ArrowCStructColumnarBatch::ArrowCStructColumnarBatch( @@ -123,8 +123,8 @@ std::shared_ptr ArrowCStructColumnarBatch::exportArrowArray() { return cArray_; } -std::pair ArrowCStructColumnarBatch::getRowBytes(int32_t rowId) const { - throw gluten::GlutenException("Not implemented getRowBytes for ArrowCStructColumnarBatch"); +std::vector ArrowCStructColumnarBatch::toUnsafeRow(int32_t rowId) const { + throw gluten::GlutenException("#toUnsafeRow of ArrowCStructColumnarBatch is not implemented"); } std::shared_ptr CompositeColumnarBatch::create(std::vector> batches) { @@ -171,8 +171,8 @@ const std::vector>& CompositeColumnarBatch::getBa return batches_; } -std::pair CompositeColumnarBatch::getRowBytes(int32_t rowId) const { - throw gluten::GlutenException("Not implemented getRowBytes for CompositeColumnarBatch"); +std::vector CompositeColumnarBatch::toUnsafeRow(int32_t rowId) const { + throw gluten::GlutenException("#toUnsafeRow of CompositeColumnarBatch is not implemented"); } CompositeColumnarBatch::CompositeColumnarBatch( diff --git a/cpp/core/memory/ColumnarBatch.h b/cpp/core/memory/ColumnarBatch.h index 4a7b34889f60..fd8189aa6a20 100644 --- a/cpp/core/memory/ColumnarBatch.h +++ b/cpp/core/memory/ColumnarBatch.h @@ -49,7 +49,8 @@ class ColumnarBatch { virtual int64_t getExportNanos() const; - virtual std::pair getRowBytes(int32_t rowId) const; + // Serializes one single row to byte array that can be accessed as Spark-compatible unsafe row. + virtual std::vector toUnsafeRow(int32_t rowId) const; friend std::ostream& operator<<(std::ostream& os, const ColumnarBatch& columnarBatch); @@ -75,7 +76,7 @@ class ArrowColumnarBatch final : public ColumnarBatch { std::shared_ptr exportArrowArray() override; - std::pair getRowBytes(int32_t rowId) const override; + std::vector toUnsafeRow(int32_t rowId) const override; private: std::shared_ptr batch_; @@ -95,7 +96,7 @@ class ArrowCStructColumnarBatch final : public ColumnarBatch { std::shared_ptr exportArrowArray() override; - std::pair getRowBytes(int32_t rowId) const override; + std::vector toUnsafeRow(int32_t rowId) const override; private: std::shared_ptr cSchema_ = std::make_shared(); @@ -120,7 +121,7 @@ class CompositeColumnarBatch final : public ColumnarBatch { const std::vector>& getBatches() const; - std::pair getRowBytes(int32_t rowId) const override; + std::vector toUnsafeRow(int32_t rowId) const override; private: explicit CompositeColumnarBatch( diff --git a/cpp/velox/memory/VeloxColumnarBatch.cc b/cpp/velox/memory/VeloxColumnarBatch.cc index 83428707b320..0d8db312721a 100644 --- a/cpp/velox/memory/VeloxColumnarBatch.cc +++ b/cpp/velox/memory/VeloxColumnarBatch.cc @@ -143,13 +143,13 @@ std::shared_ptr VeloxColumnarBatch::select( return std::make_shared(rowVector); } -std::pair VeloxColumnarBatch::getRowBytes(int32_t rowId) const { +std::vector VeloxColumnarBatch::toUnsafeRow(int32_t rowId) const { auto fast = std::make_unique(rowVector_); auto size = fast->rowSize(rowId); - char* rowBytes = new char[size]; - std::memset(rowBytes, 0, size); - fast->serialize(0, rowBytes); - return std::make_pair(rowBytes, size); + std::vector bytes(size); + std::memset(bytes.data(), 0, bytes.size()); + fast->serialize(0, bytes.data()); + return bytes; } } // namespace gluten diff --git a/cpp/velox/memory/VeloxColumnarBatch.h b/cpp/velox/memory/VeloxColumnarBatch.h index c319b7977c33..6c79f2772d2d 100644 --- a/cpp/velox/memory/VeloxColumnarBatch.h +++ b/cpp/velox/memory/VeloxColumnarBatch.h @@ -41,7 +41,7 @@ class VeloxColumnarBatch final : public ColumnarBatch { std::shared_ptr exportArrowSchema() override; std::shared_ptr exportArrowArray() override; - std::pair getRowBytes(int32_t rowId) const override; + std::vector toUnsafeRow(int32_t rowId) const override; std::shared_ptr select(facebook::velox::memory::MemoryPool* pool, std::vector columnIndices); facebook::velox::RowVectorPtr getRowVector() const; facebook::velox::RowVectorPtr getFlattenedRowVector();