Skip to content

Commit

Permalink
refactor and fix rowId
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Aug 16, 2024
1 parent f4661c3 commit 02307b4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.gluten.columnarbatch;

import org.apache.gluten.exception.GlutenException;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.Runtimes;
import org.apache.gluten.utils.ArrowAbiUtil;
Expand All @@ -34,25 +33,13 @@
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarBatchRowUtil;
import org.apache.spark.sql.vectorized.ColumnarBatchUtil;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

public class ColumnarBatches {
private static final Field FIELD_COLUMNS;

static {
try {
Field f = ColumnarBatch.class.getDeclaredField("columns");
f.setAccessible(true);
FIELD_COLUMNS = f;
} catch (NoSuchFieldException e) {
throw new GlutenException(e);
}
}

private ColumnarBatches() {}

Expand Down Expand Up @@ -91,25 +78,6 @@ private static BatchType identifyBatchType(ColumnarBatch batch) {
return BatchType.HEAVY;
}

private static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
try {
if (target.numCols() != from.numCols()) {
throw new IllegalStateException();
}
final ColumnVector[] newVectors = new ColumnVector[from.numCols()];
for (int i = 0; i < target.numCols(); i++) {
newVectors[i] = from.column(i);
}
FIELD_COLUMNS.set(target, newVectors);
// Light batch does not need the row.
if (isHeavyBatch(target)) {
ColumnarBatchRowUtil.setColumnarBatchRow(newVectors, target);
}
} catch (IllegalAccessException e) {
throw new GlutenException(e);
}
}

/** Heavy batch: Data is readable from JVM and formatted as Arrow data. */
public static boolean isHeavyBatch(ColumnarBatch batch) {
return identifyBatchType(batch) == BatchType.HEAVY;
Expand Down Expand Up @@ -206,7 +174,7 @@ private static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch input
}

// populate new vectors to input
transferVectors(output, input);
ColumnarBatchUtil.transferVectors(output, input);

return output;
}
Expand Down Expand Up @@ -242,7 +210,7 @@ private static ColumnarBatch offload(BufferAllocator allocator, ColumnarBatch in
}

// populate new vectors to input
transferVectors(output, input);
ColumnarBatchUtil.transferVectors(output, input);
return input;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,21 @@
*/
package org.apache.spark.sql.vectorized;

import org.apache.gluten.columnarbatch.ColumnarBatches;
import org.apache.gluten.exception.GlutenException;

import java.lang.reflect.Field;

public class ColumnarBatchRowUtil {
public class ColumnarBatchUtil {

// Reflective handles to ColumnarBatch's private fields. Spark does not expose
// setters for "columns" or "row", so this utility (living in the
// org.apache.spark.sql.vectorized package) patches them via reflection.
private static final Field FIELD_COLUMNS;
private static final Field FIELD_COLUMNAR_BATCH_ROW;

static {
  try {
    // ColumnarBatch.columns: the backing ColumnVector[] of a batch.
    Field f = ColumnarBatch.class.getDeclaredField("columns");
    f.setAccessible(true);
    FIELD_COLUMNS = f;
    // ColumnarBatch.row: the cached reusable ColumnarBatchRow cursor.
    Field row = ColumnarBatch.class.getDeclaredField("row");
    row.setAccessible(true);
    FIELD_COLUMNAR_BATCH_ROW = row;
  } catch (NoSuchFieldException e) {
    // Fail fast at class-load time: a Spark version without these fields is
    // incompatible with this reflection-based utility.
    throw new GlutenException(e);
  }
}

public static void setColumnarBatchRow(ColumnVector[] columns, ColumnarBatch target) {
ColumnarBatchRow row = new ColumnarBatchRow(columns);
private static void setColumnarBatchRow(
ColumnarBatch from, ColumnVector[] columns, ColumnarBatch target) {
ColumnarBatchRow newRow = new ColumnarBatchRow(columns);
try {
FIELD_COLUMNAR_BATCH_ROW.set(target, row);
ColumnarBatchRow row = (ColumnarBatchRow) FIELD_COLUMNAR_BATCH_ROW.get(from);
newRow.rowId = row.rowId;
FIELD_COLUMNAR_BATCH_ROW.set(target, newRow);
} catch (IllegalAccessException e) {
throw new GlutenException(e);
}
}

public static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
try {
if (target.numCols() != from.numCols()) {
throw new IllegalStateException();
}
final ColumnVector[] newVectors = new ColumnVector[from.numCols()];
for (int i = 0; i < target.numCols(); i++) {
newVectors[i] = from.column(i);
}
FIELD_COLUMNS.set(target, newVectors);
// Light batch does not need the row.
if (ColumnarBatches.isHeavyBatch(target)) {
setColumnarBatchRow(from, newVectors, target);
}
} catch (IllegalAccessException e) {
throw new GlutenException(e);
}
Expand Down

0 comments on commit 02307b4

Please sign in to comment.