diff --git a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java index a34cd331b865..d00efd7b80c7 100644 --- a/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java +++ b/gluten-data/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java @@ -16,7 +16,6 @@ */ package org.apache.gluten.columnarbatch; -import org.apache.gluten.exception.GlutenException; import org.apache.gluten.runtime.Runtime; import org.apache.gluten.runtime.Runtimes; import org.apache.gluten.utils.ArrowAbiUtil; @@ -34,25 +33,13 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; -import org.apache.spark.sql.vectorized.ColumnarBatchRowUtil; +import org.apache.spark.sql.vectorized.ColumnarBatchUtil; -import java.lang.reflect.Field; import java.util.Arrays; import java.util.Iterator; import java.util.NoSuchElementException; public class ColumnarBatches { - private static final Field FIELD_COLUMNS; - - static { - try { - Field f = ColumnarBatch.class.getDeclaredField("columns"); - f.setAccessible(true); - FIELD_COLUMNS = f; - } catch (NoSuchFieldException e) { - throw new GlutenException(e); - } - } private ColumnarBatches() {} @@ -91,25 +78,6 @@ private static BatchType identifyBatchType(ColumnarBatch batch) { return BatchType.HEAVY; } - private static void transferVectors(ColumnarBatch from, ColumnarBatch target) { - try { - if (target.numCols() != from.numCols()) { - throw new IllegalStateException(); - } - final ColumnVector[] newVectors = new ColumnVector[from.numCols()]; - for (int i = 0; i < target.numCols(); i++) { - newVectors[i] = from.column(i); - } - FIELD_COLUMNS.set(target, newVectors); - // Light batch does not need the row. - if (isHeavyBatch(target)) { - ColumnarBatchRowUtil.setColumnarBatchRow(newVectors, target); - } - } catch (IllegalAccessException e) { - throw new GlutenException(e); - } - } - /** Heavy batch: Data is readable from JVM and formatted as Arrow data. */ public static boolean isHeavyBatch(ColumnarBatch batch) { return identifyBatchType(batch) == BatchType.HEAVY; @@ -206,7 +174,7 @@ private static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch input } // populate new vectors to input - transferVectors(output, input); + ColumnarBatchUtil.transferVectors(output, input); return output; } @@ -242,7 +210,7 @@ private static ColumnarBatch offload(BufferAllocator allocator, ColumnarBatch in } // populate new vectors to input - transferVectors(output, input); + ColumnarBatchUtil.transferVectors(output, input); return input; } } diff --git a/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRowUtil.java b/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java similarity index 51% rename from gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRowUtil.java rename to gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java index a7ca37016fa7..0e2c74813038 100644 --- a/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchRowUtil.java +++ b/gluten-data/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java @@ -16,16 +16,21 @@ */ package org.apache.spark.sql.vectorized; +import org.apache.gluten.columnarbatch.ColumnarBatches; import org.apache.gluten.exception.GlutenException; import java.lang.reflect.Field; -public class ColumnarBatchRowUtil { +public class ColumnarBatchUtil { + private static final Field FIELD_COLUMNS; private static final Field FIELD_COLUMNAR_BATCH_ROW; static { try { + Field f = ColumnarBatch.class.getDeclaredField("columns"); + f.setAccessible(true); + FIELD_COLUMNS = f; Field row = ColumnarBatch.class.getDeclaredField("row"); row.setAccessible(true); FIELD_COLUMNAR_BATCH_ROW = row; @@ -34,10 +39,32 @@ public class ColumnarBatchRowUtil { } } - public static void setColumnarBatchRow(ColumnVector[] columns, ColumnarBatch target) { - ColumnarBatchRow row = new ColumnarBatchRow(columns); + private static void setColumnarBatchRow( + ColumnarBatch from, ColumnVector[] columns, ColumnarBatch target) { + ColumnarBatchRow newRow = new ColumnarBatchRow(columns); try { - FIELD_COLUMNAR_BATCH_ROW.set(target, row); + ColumnarBatchRow row = (ColumnarBatchRow) FIELD_COLUMNAR_BATCH_ROW.get(from); + newRow.rowId = row.rowId; + FIELD_COLUMNAR_BATCH_ROW.set(target, newRow); + } catch (IllegalAccessException e) { + throw new GlutenException(e); + } + } + + public static void transferVectors(ColumnarBatch from, ColumnarBatch target) { + try { + if (target.numCols() != from.numCols()) { + throw new IllegalStateException(); + } + final ColumnVector[] newVectors = new ColumnVector[from.numCols()]; + for (int i = 0; i < target.numCols(); i++) { + newVectors[i] = from.column(i); + } + FIELD_COLUMNS.set(target, newVectors); + // Light batch does not need the row. + if (ColumnarBatches.isHeavyBatch(target)) { + setColumnarBatchRow(from, newVectors, target); + } } catch (IllegalAccessException e) { throw new GlutenException(e); }