Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer committed Jun 28, 2024
1 parent 583e9d1 commit 6bc1e07
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.apache.spark.sql.execution

import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.gluten.exec.Runtimes
import org.apache.gluten.sql.shims.SparkShimLoader
import org.apache.gluten.vectorized.{ColumnarBatchSerializeResult, ColumnarBatchSerializerJniWrapper}

Expand All @@ -30,7 +31,6 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.TaskResources

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer;

// Utility methods to convert Vanilla broadcast relations from/to Velox broadcast relations.
Expand Down Expand Up @@ -152,12 +152,11 @@ object BroadcastUtils {
if (filtered.isEmpty) {
return ColumnarBatchSerializeResult.EMPTY
}
val batchRuntime = ColumnarBatches.getRuntime(filtered.toList.asJava)
val handleArray = filtered.map(ColumnarBatches.getNativeHandle)
val serializeResult =
try {
ColumnarBatchSerializerJniWrapper
.create(batchRuntime)
.create(Runtimes.contextInstance("BroadcastUtils#serializeStream"))
.serialize(handleArray)
} finally {
filtered.foreach(ColumnarBatches.release)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import org.apache.gluten.utils.ImplicitClass;
import org.apache.gluten.vectorized.ArrowWritableColumnVector;

import com.google.common.base.Preconditions;
import org.apache.arrow.c.ArrowArray;
import org.apache.arrow.c.ArrowSchema;
import org.apache.arrow.c.CDataDictionaryProvider;
Expand All @@ -37,11 +36,8 @@

import java.lang.reflect.Field;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Set;

public class ColumnarBatches {
private static final Field FIELD_COLUMNS;
Expand Down Expand Up @@ -131,7 +127,7 @@ public static ColumnarBatch select(ColumnarBatch batch, int[] columnIndices) {
long outputBatchHandle =
ColumnarBatchJniWrapper.create(Runtimes.contextInstance("ColumnarBatches#select"))
.select(iv.handle(), columnIndices);
return create(iv.runtime(), outputBatchHandle);
return create(Runtimes.contextInstance("ColumnarBatches#select"), outputBatchHandle);
case HEAVY:
return new ColumnarBatch(
Arrays.stream(columnIndices).mapToObj(batch::column).toArray(ColumnVector[]::new),
Expand Down Expand Up @@ -179,7 +175,7 @@ private static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch input
ArrowArray cArray = ArrowArray.allocateNew(allocator);
ArrowSchema arrowSchema = ArrowSchema.allocateNew(allocator);
CDataDictionaryProvider provider = new CDataDictionaryProvider()) {
ColumnarBatchJniWrapper.create(iv.runtime())
ColumnarBatchJniWrapper.create(Runtimes.contextInstance("ColumnarBatches#load"))
.exportToArrow(iv.handle(), cSchema.memoryAddress(), cArray.memoryAddress());

Data.exportSchema(
Expand Down Expand Up @@ -331,15 +327,9 @@ public static long compose(ColumnarBatch... batches) {
Arrays.stream(batches)
.map(ColumnarBatches::getIndicatorVector)
.toArray(IndicatorVector[]::new);
// We assume all input batches are managed by the same Runtime.
// FIXME: This check could be removed once ownership-transfer semantics are adopted.
final Runtime[] runtimes =
Arrays.stream(ivs).map(IndicatorVector::runtime).distinct().toArray(Runtime[]::new);
Preconditions.checkState(
runtimes.length == 1, "All input batches should be managed by same Runtime.");
final Runtime runtime = runtimes[0];
final long[] handles = Arrays.stream(ivs).mapToLong(IndicatorVector::handle).toArray();
return ColumnarBatchJniWrapper.create(runtime).compose(handles);
return ColumnarBatchJniWrapper.create(Runtimes.contextInstance("ColumnarBatches#compose"))
.compose(handles);
}

public static ColumnarBatch create(Runtime runtime, long nativeHandle) {
Expand Down Expand Up @@ -383,18 +373,4 @@ public static void release(ColumnarBatch b) {
/**
 * Returns the handle reported by the indicator vector of the given batch.
 *
 * @param batch the batch whose indicator vector is looked up via {@code getIndicatorVector}
 * @return the value of {@code handle()} on that indicator vector
 */
public static long getNativeHandle(ColumnarBatch batch) {
  final long nativeHandle = getIndicatorVector(batch).handle();
  return nativeHandle;
}

/**
 * Returns the runtime reported by the indicator vector of the given batch.
 *
 * @param batch the batch whose indicator vector is looked up via {@code getIndicatorVector}
 * @return the value of {@code runtime()} on that indicator vector
 */
public static Runtime getRuntime(ColumnarBatch batch) {
  final Runtime owner = getIndicatorVector(batch).runtime();
  return owner;
}

/**
 * Returns the single runtime shared by every batch in the given list.
 *
 * @param batch the batches to inspect
 * @return the one distinct runtime found across all batches
 * @throws IllegalArgumentException if the batches do not resolve to exactly one distinct
 *     runtime, which includes the case of an empty list
 */
public static Runtime getRuntime(List<ColumnarBatch> batch) {
  final Set<Runtime> distinctRuntimes = new HashSet<>();
  for (ColumnarBatch each : batch) {
    distinctRuntimes.add(getRuntime(each));
  }
  // Exactly one distinct runtime is the only acceptable outcome.
  if (distinctRuntimes.size() == 1) {
    return distinctRuntimes.iterator().next();
  }
  throw new IllegalArgumentException(
      "The input columnar batches has different associated runtimes");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,7 @@ protected IndicatorVector(Runtime runtime, long handle) {
this.handle = handle;
}

/**
 * Returns the value of this vector's {@code runtime} field.
 *
 * @return the runtime held by this instance
 */
public Runtime runtime() {
  return this.runtime;
}

/**
 * Returns the type string that {@code jniwrapper} reports for this vector's handle.
 *
 * @return the result of {@code jniwrapper.getType(handle)}
 */
public String getType() {
  final String type = jniwrapper.getType(handle);
  return type;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,10 @@ object ExecUtil {
val newHandle = ColumnarBatches.compose(pidBatch, cb)
// The composed batch already holds pidBatch's shared ref, so closing pidBatch here is safe.
ColumnarBatches.forceClose(pidBatch)
(0, ColumnarBatches.create(ColumnarBatches.getRuntime(cb), newHandle))
(
0,
ColumnarBatches
.create(Runtimes.contextInstance("ExecUtil#getShuffleDependency"), newHandle))
})
.recyclePayload(p => ColumnarBatches.forceClose(p._2)) // FIXME why force close?
.create()
Expand Down

0 comments on commit 6bc1e07

Please sign in to comment.