Skip to content

Commit

Permalink
[VL] Fix Arrow ColumnarBatch cannot invoke rowIterator correctly (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh authored Aug 16, 2024
1 parent a038e93 commit 745f1f3
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,39 @@ public void testCreateByHandle() {
});
}

@Test
public void testOffloadAndLoadReadRow() {
TaskResources$.MODULE$.runUnsafe(
() -> {
final int totalRows = 100;
final ColumnarBatch batch = newArrowBatch("a boolean, b int", totalRows);
final ArrowWritableColumnVector boolCol = (ArrowWritableColumnVector) batch.column(0);
final ArrowWritableColumnVector intCol = (ArrowWritableColumnVector) batch.column(1);
// Fill every row: alternating booleans and a descending int sequence 15, 14, ...
for (int row = 0; row < totalRows; row++) {
boolCol.putBoolean(row, (row & 1) == 0);
intCol.putInt(row, 15 - row);
}
// Make the last int entry null to cover the nullable path.
intCol.putNull(totalRows - 1);
Assert.assertTrue(ColumnarBatches.isHeavyBatch(batch));

// Round-trip the batch: heavy (Arrow) -> light (native) -> heavy again.
final ColumnarBatch offloaded =
ColumnarBatches.ensureOffloaded(ArrowBufferAllocators.contextInstance(), batch);
Assert.assertTrue(ColumnarBatches.isLightBatch(offloaded));
final ColumnarBatch loaded =
ColumnarBatches.ensureLoaded(ArrowBufferAllocators.contextInstance(), offloaded);
Assert.assertTrue(ColumnarBatches.isHeavyBatch(loaded));

// rowIterator() must yield every row of the reloaded batch.
final Spliterator<?> rows =
Spliterators.spliteratorUnknownSize(loaded.rowIterator(), Spliterator.ORDERED);
final long observedRows = StreamSupport.stream(rows, false).count();
Assert.assertEquals(totalRows, observedRows);
Assert.assertEquals(loaded.getRow(0).getInt(1), 15);
loaded.close();
return null;
});
}

private static ColumnarBatch newArrowBatch(String schema, int numRows) {
final ArrowWritableColumnVector[] columns =
ArrowWritableColumnVector.allocateColumns(numRows, StructType.fromDDL(schema));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.gluten.columnarbatch;

import org.apache.gluten.exception.GlutenException;
import org.apache.gluten.runtime.Runtime;
import org.apache.gluten.runtime.Runtimes;
import org.apache.gluten.utils.ArrowAbiUtil;
Expand All @@ -34,24 +33,13 @@
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.vectorized.ColumnarBatchUtil;

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.Iterator;
import java.util.NoSuchElementException;

public class ColumnarBatches {
private static final Field FIELD_COLUMNS;

static {
try {
Field f = ColumnarBatch.class.getDeclaredField("columns");
f.setAccessible(true);
FIELD_COLUMNS = f;
} catch (NoSuchFieldException e) {
throw new GlutenException(e);
}
}

private ColumnarBatches() {}

Expand Down Expand Up @@ -90,21 +78,6 @@ private static BatchType identifyBatchType(ColumnarBatch batch) {
return BatchType.HEAVY;
}

private static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
try {
if (target.numCols() != from.numCols()) {
throw new IllegalStateException();
}
final ColumnVector[] newVectors = new ColumnVector[from.numCols()];
for (int i = 0; i < target.numCols(); i++) {
newVectors[i] = from.column(i);
}
FIELD_COLUMNS.set(target, newVectors);
} catch (IllegalAccessException e) {
throw new GlutenException(e);
}
}

/** Heavy batch: Data is readable from JVM and formatted as Arrow data. */
public static boolean isHeavyBatch(ColumnarBatch batch) {
return identifyBatchType(batch) == BatchType.HEAVY;
Expand Down Expand Up @@ -201,8 +174,9 @@ private static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch input
}

// populate new vectors to input
transferVectors(output, input);
return input;
ColumnarBatchUtil.transferVectors(output, input);

return output;
}
}

Expand Down Expand Up @@ -236,7 +210,7 @@ private static ColumnarBatch offload(BufferAllocator allocator, ColumnarBatch in
}

// populate new vectors to input
transferVectors(output, input);
ColumnarBatchUtil.transferVectors(output, input);
return input;
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.vectorized;

import org.apache.gluten.columnarbatch.ColumnarBatches;
import org.apache.gluten.exception.GlutenException;

import java.lang.reflect.Field;

/**
 * Reflection-based helpers for mutating the internals of a {@link ColumnarBatch}.
 *
 * <p>Lives in the {@code org.apache.spark.sql.vectorized} package so it can reference the
 * package-private {@link ColumnarBatchRow} type.
 */
public final class ColumnarBatchUtil {

  // ColumnarBatch#columns: the backing vector array, replaced wholesale on transfer.
  private static final Field FIELD_COLUMNS;
  // ColumnarBatch#row: the reusable row view handed out by getRow()/rowIterator().
  private static final Field FIELD_COLUMNAR_BATCH_ROW;

  static {
    try {
      Field f = ColumnarBatch.class.getDeclaredField("columns");
      f.setAccessible(true);
      FIELD_COLUMNS = f;
      Field row = ColumnarBatch.class.getDeclaredField("row");
      row.setAccessible(true);
      FIELD_COLUMNAR_BATCH_ROW = row;
    } catch (NoSuchFieldException e) {
      throw new GlutenException(e);
    }
  }

  // Utility class: static methods only, never instantiated.
  private ColumnarBatchUtil() {}

  /**
   * Installs a fresh {@link ColumnarBatchRow} over {@code columns} into {@code target},
   * carrying over the {@code rowId} of {@code from}'s cached row so the row view stays
   * positioned consistently with the source batch.
   */
  private static void setColumnarBatchRow(
      ColumnarBatch from, ColumnVector[] columns, ColumnarBatch target) {
    ColumnarBatchRow newRow = new ColumnarBatchRow(columns);
    try {
      ColumnarBatchRow row = (ColumnarBatchRow) FIELD_COLUMNAR_BATCH_ROW.get(from);
      newRow.rowId = row.rowId;
      FIELD_COLUMNAR_BATCH_ROW.set(target, newRow);
    } catch (IllegalAccessException e) {
      throw new GlutenException(e);
    }
  }

  /**
   * Copies the column vectors of {@code from} into {@code target}, replacing {@code target}'s
   * backing column array. For heavy batches the cached row view is rebuilt as well so that
   * subsequent {@code rowIterator()}/{@code getRow()} calls read the new vectors.
   *
   * @param from the batch supplying the column vectors
   * @param target the batch whose internals are overwritten
   * @throws IllegalStateException if the two batches have different column counts
   * @throws GlutenException if reflective access to ColumnarBatch internals fails
   */
  public static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
    try {
      if (target.numCols() != from.numCols()) {
        // Include the counts so a mismatch is diagnosable from the exception alone.
        throw new IllegalStateException(
            "Cannot transfer vectors: source batch has "
                + from.numCols()
                + " columns but target batch has "
                + target.numCols());
      }
      final ColumnVector[] newVectors = new ColumnVector[from.numCols()];
      for (int i = 0; i < from.numCols(); i++) {
        newVectors[i] = from.column(i);
      }
      FIELD_COLUMNS.set(target, newVectors);
      // Light batch does not need the row.
      if (ColumnarBatches.isHeavyBatch(target)) {
        setColumnarBatchRow(from, newVectors, target);
      }
    } catch (IllegalAccessException e) {
      throw new GlutenException(e);
    }
  }
}

0 comments on commit 745f1f3

Please sign in to comment.