From 745f1f39fef57b6ec6883b9d4458116c29bfcdd9 Mon Sep 17 00:00:00 2001
From: Jin Chengcheng
Date: Fri, 16 Aug 2024 15:28:43 +0800
Subject: [PATCH] [VL] Fix Arrow ColumnarBatch cannot invoke rowIterator
 correctly (#6797)

---
 .../columnarbatch/ColumnarBatchTest.java      | 33 +++++++++
 .../gluten/columnarbatch/ColumnarBatches.java | 36 ++--------
 .../sql/vectorized/ColumnarBatchUtil.java     | 72 +++++++++++++++++++
 3 files changed, 110 insertions(+), 31 deletions(-)
 create mode 100644 gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java

diff --git a/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java b/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
index cd2ac50d350c..3b78a4067793 100644
--- a/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
+++ b/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
@@ -91,6 +91,39 @@ public void testCreateByHandle() {
         });
   }
 
+  @Test
+  public void testOffloadAndLoadReadRow() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final int numRows = 100;
+          final ColumnarBatch batch = newArrowBatch("a boolean, b int", numRows);
+          final ArrowWritableColumnVector col0 = (ArrowWritableColumnVector) batch.column(0);
+          final ArrowWritableColumnVector col1 = (ArrowWritableColumnVector) batch.column(1);
+          for (int j = 0; j < numRows; j++) {
+            col0.putBoolean(j, j % 2 == 0);
+            col1.putInt(j, 15 - j);
+          }
+          col1.putNull(numRows - 1);
+          Assert.assertTrue(ColumnarBatches.isHeavyBatch(batch));
+          final ColumnarBatch offloaded =
+              ColumnarBatches.ensureOffloaded(ArrowBufferAllocators.contextInstance(), batch);
+          Assert.assertTrue(ColumnarBatches.isLightBatch(offloaded));
+          final ColumnarBatch loaded =
+              ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance(), offloaded);
+          Assert.assertTrue(ColumnarBatches.isHeavyBatch(loaded));
+          long cnt =
+              StreamSupport.stream(
+                      Spliterators.spliteratorUnknownSize(
+                          loaded.rowIterator(), Spliterator.ORDERED),
+                      false)
+                  .count();
+          Assert.assertEquals(numRows, cnt);
+          Assert.assertEquals(loaded.getRow(0).getInt(1), 15);
+          loaded.close();
+          return null;
+        });
+  }
+
   private static ColumnarBatch newArrowBatch(String schema, int numRows) {
     final ArrowWritableColumnVector[] columns =
         ArrowWritableColumnVector.allocateColumns(numRows, StructType.fromDDL(schema));
diff --git a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
index 543e6d4cf97b..d00efd7b80c7 100644
--- a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
+++ b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
@@ -16,7 +16,6 @@
  */
 package org.apache.gluten.columnarbatch;
 
-import org.apache.gluten.exception.GlutenException;
 import org.apache.gluten.runtime.Runtime;
 import org.apache.gluten.runtime.Runtimes;
 import org.apache.gluten.utils.ArrowAbiUtil;
@@ -34,24 +33,13 @@
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
 import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.apache.spark.sql.vectorized.ColumnarBatchUtil;
 
-import java.lang.reflect.Field;
 import java.util.Arrays;
 import java.util.Iterator;
 import java.util.NoSuchElementException;
 
 public class ColumnarBatches {
-  private static final Field FIELD_COLUMNS;
-
-  static {
-    try {
-      Field f = ColumnarBatch.class.getDeclaredField("columns");
-      f.setAccessible(true);
-      FIELD_COLUMNS = f;
-    } catch (NoSuchFieldException e) {
-      throw new GlutenException(e);
-    }
-  }
 
   private ColumnarBatches() {}
 
@@ -90,21 +78,6 @@ private static BatchType identifyBatchType(ColumnarBatch batch) {
     return BatchType.HEAVY;
   }
 
-  private static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
-    try {
-      if (target.numCols() != from.numCols()) {
-        throw new IllegalStateException();
-      }
-      final ColumnVector[] newVectors = new ColumnVector[from.numCols()];
-      for (int i = 0; i < target.numCols(); i++) {
-        newVectors[i] = from.column(i);
-      }
-      FIELD_COLUMNS.set(target, newVectors);
-    } catch (IllegalAccessException e) {
-      throw new GlutenException(e);
-    }
-  }
-
   /** Heavy batch: Data is readable from JVM and formatted as Arrow data. */
   public static boolean isHeavyBatch(ColumnarBatch batch) {
     return identifyBatchType(batch) == BatchType.HEAVY;
   }
@@ -201,8 +174,9 @@ private static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch input
       }
 
       // populate new vectors to input
-      transferVectors(output, input);
-      return input;
+      ColumnarBatchUtil.transferVectors(output, input);
+
+      return output;
     }
   }
 
@@ -236,7 +210,7 @@ private static ColumnarBatch offload(BufferAllocator allocator, ColumnarBatch in
       }
 
       // populate new vectors to input
-      transferVectors(output, input);
+      ColumnarBatchUtil.transferVectors(output, input);
       return input;
     }
   }
diff --git a/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java b/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java
new file mode 100644
index 000000000000..0e2c74813038
--- /dev/null
+++ b/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.vectorized;
+
+import org.apache.gluten.columnarbatch.ColumnarBatches;
+import org.apache.gluten.exception.GlutenException;
+
+import java.lang.reflect.Field;
+
+public class ColumnarBatchUtil {
+
+  private static final Field FIELD_COLUMNS;
+  private static final Field FIELD_COLUMNAR_BATCH_ROW;
+
+  static {
+    try {
+      Field f = ColumnarBatch.class.getDeclaredField("columns");
+      f.setAccessible(true);
+      FIELD_COLUMNS = f;
+      Field row = ColumnarBatch.class.getDeclaredField("row");
+      row.setAccessible(true);
+      FIELD_COLUMNAR_BATCH_ROW = row;
+    } catch (NoSuchFieldException e) {
+      throw new GlutenException(e);
+    }
+  }
+
+  private static void setColumnarBatchRow(
+      ColumnarBatch from, ColumnVector[] columns, ColumnarBatch target) {
+    ColumnarBatchRow newRow = new ColumnarBatchRow(columns);
+    try {
+      ColumnarBatchRow row = (ColumnarBatchRow) FIELD_COLUMNAR_BATCH_ROW.get(from);
+      newRow.rowId = row.rowId;
+      FIELD_COLUMNAR_BATCH_ROW.set(target, newRow);
+    } catch (IllegalAccessException e) {
+      throw new GlutenException(e);
+    }
+  }
+
+  public static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
+    try {
+      if (target.numCols() != from.numCols()) {
+        throw new IllegalStateException();
+      }
+      final ColumnVector[] newVectors = new ColumnVector[from.numCols()];
+      for (int i = 0; i < target.numCols(); i++) {
+        newVectors[i] = from.column(i);
+      }
+      FIELD_COLUMNS.set(target, newVectors);
+      // Light batch does not need the row.
+      if (ColumnarBatches.isHeavyBatch(target)) {
+        setColumnarBatchRow(from, newVectors, target);
+      }
+    } catch (IllegalAccessException e) {
+      throw new GlutenException(e);
+    }
+  }
+}
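
Reviewer note, not part of the patch: the new setColumnarBatchRow step exists
because Spark's ColumnarBatch builds a ColumnarBatchRow over the original
`columns` array in its constructor and hands that cached row back from
getRow(). Swapping only the `columns` field via reflection therefore leaves
getRow() reading the old, possibly already-released vectors. Below is a
minimal standalone sketch of that staleness; it assumes a Spark 3.x
ColumnarBatch, uses OnHeapColumnVector purely for illustration, and the class
name RowCacheDemo is made up and does not appear in the patch.

import java.lang.reflect.Field;

import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;

public class RowCacheDemo {
  public static void main(String[] args) throws Exception {
    // Build a one-row batch over an "old" vector holding the value 1.
    OnHeapColumnVector oldCol = new OnHeapColumnVector(1, DataTypes.IntegerType);
    oldCol.putInt(0, 1);
    ColumnarBatch batch = new ColumnarBatch(new ColumnVector[] {oldCol}, 1);

    // Prepare a "new" vector holding 42 and swap it into the batch the same
    // way ColumnarBatchUtil.transferVectors swaps the `columns` field.
    OnHeapColumnVector newCol = new OnHeapColumnVector(1, DataTypes.IntegerType);
    newCol.putInt(0, 42);
    Field columns = ColumnarBatch.class.getDeclaredField("columns");
    columns.setAccessible(true);
    columns.set(batch, new ColumnVector[] {newCol});

    // rowIterator() constructs a fresh row over the swapped columns and sees 42,
    // while getRow() reuses the row cached at construction time and still sees 1.
    // Replacing the cached `row` field as well, as setColumnarBatchRow does,
    // removes this mismatch.
    System.out.println(batch.rowIterator().next().getInt(0)); // 42
    System.out.println(batch.getRow(0).getInt(0)); // 1 (stale)
  }
}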