Commit d8b1312

refactor

jinchengchenghh committed Jul 15, 2024
1 parent 0309dbb commit d8b1312

Showing 7 changed files with 80 additions and 40 deletions.
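The refactor replaces the pattern of call sites fabricating unused SQLMetric objects with metric-free convenience overloads of RowToVeloxColumnarExec.toColumnarBatchIterator and VeloxColumnarToRowExec.toRowIterator; it also gives ColumnarBatch::toString a default throwing implementation on the C++ side, exposes ColumnarBatches.toString for tests, and drops a stray System.out.println(). A minimal before/after sketch of the call-site shape (rows, schema, and batchSize are illustrative names, not from the diff):

import org.apache.gluten.execution.RowToVeloxColumnarExec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.vectorized.ColumnarBatch

// Before: every metric-less caller fabricated throwaway SQLMetrics.
def toBatchesBefore(
    rows: Iterator[InternalRow],
    schema: StructType,
    batchSize: Int): Iterator[ColumnarBatch] = {
  val numInputRows = new SQLMetric("numInputRows")
  val numOutputBatches = new SQLMetric("numOutputBatches")
  val convertTime = new SQLMetric("convertTime")
  RowToVeloxColumnarExec.toColumnarBatchIterator(
    rows, schema, numInputRows, numOutputBatches, convertTime, batchSize)
}

// After: the placeholder metrics are created inside the new overload.
def toBatchesAfter(
    rows: Iterator[InternalRow],
    schema: StructType,
    batchSize: Int): Iterator[ColumnarBatch] =
  RowToVeloxColumnarExec.toColumnarBatchIterator(rows, schema, batchSize)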
@@ -34,7 +34,6 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, JoinedRow}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFileLinesReader, OutputWriterFactory, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
-import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
 import org.apache.spark.sql.types.{StructField, StructType}
@@ -310,16 +309,9 @@ object ArrowCSVFileFormat {
       schema: StructType,
       batchSize: Int,
       it: Iterator[InternalRow]): Iterator[ColumnarBatch] = {
-    // note, these metrics are unused but just make `RowToVeloxColumnarExec` happy
-    val numInputRows = new SQLMetric("numInputRows")
-    val numOutputBatches = new SQLMetric("numOutputBatches")
-    val convertTime = new SQLMetric("convertTime")
     val veloxBatch = RowToVeloxColumnarExec.toColumnarBatchIterator(
       it,
       schema,
-      numInputRows,
-      numOutputBatches,
-      convertTime,
       batchSize
     )
     veloxBatch
@@ -92,6 +92,23 @@ case class RowToVeloxColumnarExec(child: SparkPlan) extends RowToColumnarExecBase
 }
 
 object RowToVeloxColumnarExec {
+
+  def toColumnarBatchIterator(
+      it: Iterator[InternalRow],
+      schema: StructType,
+      columnBatchSize: Int): Iterator[ColumnarBatch] = {
+    val numInputRows = new SQLMetric("numInputRows")
+    val numOutputBatches = new SQLMetric("numOutputBatches")
+    val convertTime = new SQLMetric("convertTime")
+    RowToVeloxColumnarExec.toColumnarBatchIterator(
+      it,
+      schema,
+      numInputRows,
+      numOutputBatches,
+      convertTime,
+      columnBatchSize)
+  }
+
   def toColumnarBatchIterator(
       it: Iterator[InternalRow],
       schema: StructType,
@@ -99,6 +99,22 @@ case class VeloxColumnarToRowExec(child: SparkPlan) extends ColumnarToRowExecBase
 }
 
 object VeloxColumnarToRowExec {
+
+  def toRowIterator(
+      batches: Iterator[ColumnarBatch],
+      output: Seq[Attribute]): Iterator[InternalRow] = {
+    val numOutputRows = new SQLMetric("numOutputRows")
+    val numInputBatches = new SQLMetric("numInputBatches")
+    val convertTime = new SQLMetric("convertTime")
+    toRowIterator(
+      batches,
+      output,
+      numOutputRows,
+      numInputBatches,
+      convertTime
+    )
+  }
+
   def toRowIterator(
       batches: Iterator[ColumnarBatch],
       output: Seq[Attribute],
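Both converters now have the same metric-free shape, so a row/columnar round trip no longer threads SQLMetric arguments through. A short sketch under the same assumptions as above (rows, schema, and output are hypothetical; both classes are imported from org.apache.gluten.execution, as in the test below):

import org.apache.gluten.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.types.StructType

// Rows -> Velox batches -> rows; placeholder metrics are created internally.
def roundTrip(
    rows: Iterator[InternalRow],
    schema: StructType,
    output: Seq[Attribute]): Iterator[InternalRow] = {
  val batches = RowToVeloxColumnarExec.toColumnarBatchIterator(rows, schema, 4096)
  VeloxColumnarToRowExec.toRowIterator(batches, output)
}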
@@ -28,7 +28,6 @@ import org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.columnar.{CachedBatch, CachedBatchSerializer}
@@ -134,22 +133,9 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHelper
         conf)
     }
 
-    // note, these metrics are unused but just make `RowToVeloxColumnarExec` happy
-    val metrics = BackendsApiManager.getMetricsApiInstance.genRowToColumnarMetrics(
-      SparkSession.getActiveSession.orNull.sparkContext)
-    val numInputRows = metrics("numInputRows")
-    val numOutputBatches = metrics("numOutputBatches")
-    val convertTime = metrics("convertTime")
     val numRows = conf.columnBatchSize
     val rddColumnarBatch = input.mapPartitions {
-      it =>
-        RowToVeloxColumnarExec.toColumnarBatchIterator(
-          it,
-          localSchema,
-          numInputRows,
-          numOutputBatches,
-          convertTime,
-          numRows)
+      it => RowToVeloxColumnarExec.toColumnarBatchIterator(it, localSchema, numRows)
     }
     convertColumnarBatchToCachedBatch(rddColumnarBatch, schema, storageLevel, conf)
   }
@@ -169,22 +155,10 @@ class ColumnarCachedBatchSerializer extends CachedBatchSerializer with SQLConfHelper
         conf)
     }
 
-    // note, these metrics are unused but just make `VeloxColumnarToRowExec` happy
-    val metrics = BackendsApiManager.getMetricsApiInstance.genColumnarToRowMetrics(
-      SparkSession.getActiveSession.orNull.sparkContext)
-    val numOutputRows = metrics("numOutputRows")
-    val numInputBatches = metrics("numInputBatches")
-    val convertTime = metrics("convertTime")
     val rddColumnarBatch =
       convertCachedBatchToColumnarBatch(input, cacheAttributes, selectedAttributes, conf)
     rddColumnarBatch.mapPartitions {
-      it =>
-        VeloxColumnarToRowExec.toRowIterator(
-          it,
-          selectedAttributes,
-          numOutputRows,
-          numInputBatches,
-          convertTime)
+      it => VeloxColumnarToRowExec.toRowIterator(it, selectedAttributes)
     }
   }

@@ -16,10 +16,12 @@
  */
 package org.apache.gluten.columnarbatch;
 
+import org.apache.gluten.execution.RowToVeloxColumnarExec;
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
 import org.apache.gluten.test.VeloxBackendTestBase;
 import org.apache.gluten.vectorized.ArrowWritableColumnVector;
 
+import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 import org.apache.spark.util.TaskResources$;
@@ -30,6 +32,8 @@
 import java.util.Spliterators;
 import java.util.stream.StreamSupport;
 
+import scala.collection.JavaConverters;
+
 public class ColumnarBatchTest extends VeloxBackendTestBase {
 
   @Test
@@ -91,6 +95,25 @@ public void testCreateByHandle() {
         });
   }
 
+  @Test
+  public void testToString() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final ColumnarBatch batch = newArrowBatch(20);
+          StructType structType = new StructType();
+          structType = structType.add("a", DataTypes.BooleanType, true);
+          structType = structType.add("b", DataTypes.IntegerType, true);
+          ColumnarBatch veloxBatch =
+              RowToVeloxColumnarExec.toColumnarBatchIterator(
+                      JavaConverters.asScalaIterator(batch.rowIterator()), structType, 20)
+                  .next();
+          Assert.assertEquals(
+              "0: {true, 15}\n1: {false, 14}", ColumnarBatches.toString(veloxBatch, 0, 2));
+          veloxBatch.close();
+          return null;
+        });
+  }
+
   private static ColumnarBatch newArrowBatch(String schema, int numRows) {
     final ArrowWritableColumnVector[] columns =
         ArrowWritableColumnVector.allocateColumns(numRows, StructType.fromDDL(schema));
@@ -101,4 +124,22 @@ private static ColumnarBatch newArrowBatch(String schema, int numRows) {
     batch.setNumRows(numRows);
     return batch;
   }
+
+  private static ColumnarBatch newArrowBatch(int numRows) {
+    String schema = "a boolean, b int";
+    final ArrowWritableColumnVector[] columns =
+        ArrowWritableColumnVector.allocateColumns(numRows, StructType.fromDDL(schema));
+    ArrowWritableColumnVector col1 = columns[0];
+    ArrowWritableColumnVector col2 = columns[1];
+    for (int j = 0; j < numRows; j++) {
+      col1.putBoolean(j, j % 2 == 0);
+      col2.putInt(j, 15 - j);
+    }
+    for (ArrowWritableColumnVector col : columns) {
+      col.setValueCount(numRows);
+    }
+    final ColumnarBatch batch = new ColumnarBatch(columns);
+    batch.setNumRows(numRows);
+    return batch;
+  }
 }
cpp/core/memory/ColumnarBatch.h (4 changes: 3 additions & 1 deletion)
@@ -52,7 +52,9 @@ class ColumnarBatch {
   // Serializes one single row to byte array that can be accessed as Spark-compatible unsafe row.
   virtual std::vector<char> toUnsafeRow(int32_t rowId) const;
 
-  virtual std::string toString(int32_t start, int32_t length) const = 0;
+  virtual std::string toString(int32_t start, int32_t length) const {
+    throw GlutenException("Not implemented");
+  }
 
   friend std::ostream& operator<<(std::ostream& os, const ColumnarBatch& columnarBatch);
 
@@ -100,7 +100,6 @@ private static void transferVectors(ColumnarBatch from, ColumnarBatch target) {
         newVectors[i] = from.column(i);
       }
       FIELD_COLUMNS.set(target, newVectors);
-      System.out.println();
     } catch (IllegalAccessException e) {
       throw new GlutenException(e);
     }
@@ -382,8 +381,7 @@ public static long getNativeHandle(ColumnarBatch batch) {
   }
 
   public static String toString(ColumnarBatch batch, int start, int length) {
-    return ColumnarBatchJniWrapper
-        .create(Runtimes.contextInstance("ColumnarBatches#toString"))
+    return ColumnarBatchJniWrapper.create(Runtimes.contextInstance("ColumnarBatches#toString"))
         .toString(getNativeHandle(batch), start, length);
   }
 }
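The new ColumnarBatches.toString helper delegates to the native handle, so it is an easy way to eyeball an offloaded batch while debugging. A hedged sketch from Scala (veloxBatch is assumed to be a batch with a native handle, like the one built in the test above):

import org.apache.gluten.columnarbatch.ColumnarBatches
import org.apache.spark.sql.vectorized.ColumnarBatch

// Print rows [0, 2) of an offloaded batch; for the test data above this
// yields "0: {true, 15}\n1: {false, 14}".
def dumpHead(veloxBatch: ColumnarBatch): Unit =
  println(ColumnarBatches.toString(veloxBatch, 0, 2))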
