Skip to content

Commit

Permalink
[VL] Make ColumnarBatch::getRowBytes leak-safe (#6002)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Jun 6, 2024
1 parent d1b3e99 commit f72349e
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 36 deletions.
29 changes: 11 additions & 18 deletions cpp/core/jni/JniWrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ static jclass shuffleReaderMetricsClass;
static jmethodID shuffleReaderMetricsSetDecompressTime;
static jmethodID shuffleReaderMetricsSetDeserializeTime;

static jclass block_stripes_class;
static jmethodID block_stripes_constructor;
static jclass blockStripesClass;
static jmethodID blockStripesConstructor;

class JavaInputStreamAdaptor final : public arrow::io::InputStream {
public:
Expand Down Expand Up @@ -280,9 +280,9 @@ jint JNI_OnLoad(JavaVM* vm, void* reserved) {
shuffleReaderMetricsSetDeserializeTime =
getMethodIdOrError(env, shuffleReaderMetricsClass, "setDeserializeTime", "(J)V");

block_stripes_class =
blockStripesClass =
createGlobalClassReferenceOrError(env, "Lorg/apache/spark/sql/execution/datasources/BlockStripes;");
block_stripes_constructor = env->GetMethodID(block_stripes_class, "<init>", "(J[J[II[B)V");
blockStripesConstructor = env->GetMethodID(blockStripesClass, "<init>", "(J[J[II[B)V");

return jniVersion;
}
Expand All @@ -297,7 +297,7 @@ void JNI_OnUnload(JavaVM* vm, void* reserved) {
env->DeleteGlobalRef(nativeColumnarToRowInfoClass);
env->DeleteGlobalRef(byteArrayClass);
env->DeleteGlobalRef(shuffleReaderMetricsClass);
env->DeleteGlobalRef(block_stripes_class);
env->DeleteGlobalRef(blockStripesClass);

gluten::getJniErrorState()->close();
gluten::getJniCommonState()->close();
Expand Down Expand Up @@ -1224,30 +1224,23 @@ Java_org_apache_gluten_datasource_DatasourceJniWrapper_splitBlockByPartitionAndB
}

MemoryManager* memoryManager = reinterpret_cast<MemoryManager*>(memoryManagerId);
auto result = batch->getRowBytes(0);
auto rowBytes = result.first;
auto result = batch->toUnsafeRow(0);
auto rowBytes = result.data();
auto newBatchHandle = ctx->objectStore()->save(ctx->select(memoryManager, batch, partitionColIndiceVec));

auto bytesSize = result.second;
auto bytesSize = result.size();
jbyteArray bytesArray = env->NewByteArray(bytesSize);
env->SetByteArrayRegion(bytesArray, 0, bytesSize, reinterpret_cast<jbyte*>(rowBytes));
delete[] rowBytes;

jlongArray batchArray = env->NewLongArray(1);
long* cBatchArray = new long[1];
cBatchArray[0] = newBatchHandle;
env->SetLongArrayRegion(batchArray, 0, 1, cBatchArray);
delete[] cBatchArray;

jobject block_stripes = env->NewObject(
block_stripes_class,
block_stripes_constructor,
batchHandle,
batchArray,
nullptr,
batch->numColumns(),
bytesArray);
return block_stripes;
jobject blockStripes = env->NewObject(
blockStripesClass, blockStripesConstructor, batchHandle, batchArray, nullptr, batch->numColumns(), bytesArray);
return blockStripes;
JNI_METHOD_END(nullptr)
}

Expand Down
16 changes: 8 additions & 8 deletions cpp/core/memory/ColumnarBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ int64_t ColumnarBatch::getExportNanos() const {
return exportNanos_;
}

// Diff hunk: each removed line is followed by its replacement. getRowBytes is renamed to toUnsafeRow
// and the return type changes from std::pair<char*, int> to std::vector<char>.
// Both versions unconditionally throw gluten::GlutenException and never read rowId.
std::pair<char*, int> ColumnarBatch::getRowBytes(int32_t rowId) const {
throw gluten::GlutenException("Not implemented getRowBytes for ColumnarBatch");
std::vector<char> ColumnarBatch::toUnsafeRow(int32_t rowId) const {
throw gluten::GlutenException("Not implemented toUnsafeRow for ColumnarBatch");
}

std::ostream& operator<<(std::ostream& os, const ColumnarBatch& columnarBatch) {
Expand Down Expand Up @@ -86,8 +86,8 @@ std::shared_ptr<ArrowArray> ArrowColumnarBatch::exportArrowArray() {
return cArray;
}

// Diff hunk: each removed line is followed by its replacement. getRowBytes is renamed to toUnsafeRow,
// the return type changes from std::pair<char*, int> to std::vector<char>, and the exception message
// is reworded. Both versions unconditionally throw gluten::GlutenException and never read rowId.
std::pair<char*, int> ArrowColumnarBatch::getRowBytes(int32_t rowId) const {
throw gluten::GlutenException("Not implemented getRowBytes for ArrowColumnarBatch");
std::vector<char> ArrowColumnarBatch::toUnsafeRow(int32_t rowId) const {
throw gluten::GlutenException("#toUnsafeRow of ArrowColumnarBatch is not implemented");
}

ArrowCStructColumnarBatch::ArrowCStructColumnarBatch(
Expand Down Expand Up @@ -123,8 +123,8 @@ std::shared_ptr<ArrowArray> ArrowCStructColumnarBatch::exportArrowArray() {
return cArray_;
}

// Diff hunk: each removed line is followed by its replacement. getRowBytes is renamed to toUnsafeRow,
// the return type changes from std::pair<char*, int> to std::vector<char>, and the exception message
// is reworded. Both versions unconditionally throw gluten::GlutenException and never read rowId.
std::pair<char*, int> ArrowCStructColumnarBatch::getRowBytes(int32_t rowId) const {
throw gluten::GlutenException("Not implemented getRowBytes for ArrowCStructColumnarBatch");
std::vector<char> ArrowCStructColumnarBatch::toUnsafeRow(int32_t rowId) const {
throw gluten::GlutenException("#toUnsafeRow of ArrowCStructColumnarBatch is not implemented");
}

std::shared_ptr<ColumnarBatch> CompositeColumnarBatch::create(std::vector<std::shared_ptr<ColumnarBatch>> batches) {
Expand Down Expand Up @@ -171,8 +171,8 @@ const std::vector<std::shared_ptr<ColumnarBatch>>& CompositeColumnarBatch::getBa
return batches_;
}

// Diff hunk: each removed line is followed by its replacement. getRowBytes is renamed to toUnsafeRow,
// the return type changes from std::pair<char*, int> to std::vector<char>, and the exception message
// is reworded. Both versions unconditionally throw gluten::GlutenException and never read rowId.
std::pair<char*, int> CompositeColumnarBatch::getRowBytes(int32_t rowId) const {
throw gluten::GlutenException("Not implemented getRowBytes for CompositeColumnarBatch");
std::vector<char> CompositeColumnarBatch::toUnsafeRow(int32_t rowId) const {
throw gluten::GlutenException("#toUnsafeRow of CompositeColumnarBatch is not implemented");
}

CompositeColumnarBatch::CompositeColumnarBatch(
Expand Down
9 changes: 5 additions & 4 deletions cpp/core/memory/ColumnarBatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ class ColumnarBatch {

virtual int64_t getExportNanos() const;

virtual std::pair<char*, int> getRowBytes(int32_t rowId) const;
// Serializes a single row into a byte array that can be accessed as a Spark-compatible unsafe row.
virtual std::vector<char> toUnsafeRow(int32_t rowId) const;

friend std::ostream& operator<<(std::ostream& os, const ColumnarBatch& columnarBatch);

Expand All @@ -75,7 +76,7 @@ class ArrowColumnarBatch final : public ColumnarBatch {

std::shared_ptr<ArrowArray> exportArrowArray() override;

std::pair<char*, int> getRowBytes(int32_t rowId) const override;
std::vector<char> toUnsafeRow(int32_t rowId) const override;

private:
std::shared_ptr<arrow::RecordBatch> batch_;
Expand All @@ -95,7 +96,7 @@ class ArrowCStructColumnarBatch final : public ColumnarBatch {

std::shared_ptr<ArrowArray> exportArrowArray() override;

std::pair<char*, int> getRowBytes(int32_t rowId) const override;
std::vector<char> toUnsafeRow(int32_t rowId) const override;

private:
std::shared_ptr<ArrowSchema> cSchema_ = std::make_shared<ArrowSchema>();
Expand All @@ -120,7 +121,7 @@ class CompositeColumnarBatch final : public ColumnarBatch {

const std::vector<std::shared_ptr<ColumnarBatch>>& getBatches() const;

std::pair<char*, int> getRowBytes(int32_t rowId) const override;
std::vector<char> toUnsafeRow(int32_t rowId) const override;

private:
explicit CompositeColumnarBatch(
Expand Down
10 changes: 5 additions & 5 deletions cpp/velox/memory/VeloxColumnarBatch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,13 @@ std::shared_ptr<ColumnarBatch> VeloxColumnarBatch::select(
return std::make_shared<VeloxColumnarBatch>(rowVector);
}

std::pair<char*, int> VeloxColumnarBatch::getRowBytes(int32_t rowId) const {
std::vector<char> VeloxColumnarBatch::toUnsafeRow(int32_t rowId) const {
auto fast = std::make_unique<facebook::velox::row::UnsafeRowFast>(rowVector_);
auto size = fast->rowSize(rowId);
char* rowBytes = new char[size];
std::memset(rowBytes, 0, size);
fast->serialize(0, rowBytes);
return std::make_pair(rowBytes, size);
std::vector<char> bytes(size);
std::memset(bytes.data(), 0, bytes.size());
fast->serialize(0, bytes.data());
return bytes;
}

} // namespace gluten
2 changes: 1 addition & 1 deletion cpp/velox/memory/VeloxColumnarBatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class VeloxColumnarBatch final : public ColumnarBatch {

std::shared_ptr<ArrowSchema> exportArrowSchema() override;
std::shared_ptr<ArrowArray> exportArrowArray() override;
std::pair<char*, int> getRowBytes(int32_t rowId) const override;
std::vector<char> toUnsafeRow(int32_t rowId) const override;
std::shared_ptr<ColumnarBatch> select(facebook::velox::memory::MemoryPool* pool, std::vector<int32_t> columnIndices);
facebook::velox::RowVectorPtr getRowVector() const;
facebook::velox::RowVectorPtr getFlattenedRowVector();
Expand Down

0 comments on commit f72349e

Please sign in to comment.