feat: Hook DataFusion Parquet native scan with Comet execution (#1094)
* init

* more

* fix

* more

* more

* fix
viirya authored Nov 19, 2024
1 parent 8a0df9d commit 1cca8d6
Showing 7 changed files with 166 additions and 463 deletions.
18 changes: 8 additions & 10 deletions native/core/src/execution/datafusion/planner.rs
@@ -1023,17 +1023,15 @@ impl PhysicalPlanner {
             .with_file_groups(file_groups);

         // Check for projection, if so generate the vector and add to FileScanConfig.
-        if !required_schema_arrow.fields.is_empty() {
-            let mut projection_vector: Vec<usize> =
-                Vec::with_capacity(required_schema_arrow.fields.len());
-            // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
-            required_schema_arrow.fields.iter().for_each(|field| {
-                projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
-            });
+        let mut projection_vector: Vec<usize> =
+            Vec::with_capacity(required_schema_arrow.fields.len());
+        // TODO: could be faster with a hashmap rather than iterating over data_schema_arrow with index_of.
+        required_schema_arrow.fields.iter().for_each(|field| {
+            projection_vector.push(data_schema_arrow.index_of(field.name()).unwrap());
+        });

-            assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());
-            file_scan_config = file_scan_config.with_projection(Some(projection_vector));
-        }
+        assert_eq!(projection_vector.len(), required_schema_arrow.fields.len());
+        file_scan_config = file_scan_config.with_projection(Some(projection_vector));

         let mut table_parquet_options = TableParquetOptions::new();
         // TODO: Maybe these are configs?
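With the emptiness guard removed, the projection vector above is now built unconditionally: each field of the required (projected) schema is mapped to its ordinal in the full file schema, and the resulting vector is handed to DataFusion's FileScanConfig so that only those columns are decoded. A minimal Scala sketch of the same index-mapping idea, using hypothetical schemas rather than Comet's Rust types:

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Full file schema as written in the Parquet footer (hypothetical example).
    val dataSchema = StructType(Seq(
      StructField("id", IntegerType),
      StructField("name", StringType),
      StructField("amount", IntegerType)))

    // Columns the query actually needs, possibly reordered.
    val requiredSchema = StructType(Seq(
      StructField("amount", IntegerType),
      StructField("id", IntegerType)))

    // One index per required field, pointing into the full schema.
    val projection: Seq[Int] = requiredSchema.fields.toSeq.map(f => dataSchema.fieldIndex(f.name))
    assert(projection == Seq(2, 0))

As the TODO in the hunk notes, a name-to-index map would avoid the repeated linear index_of lookups on wide schemas.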
spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -202,8 +202,9 @@ class CometSparkSessionExtensions
             if CometNativeScanExec.isSchemaSupported(requiredSchema)
               && CometNativeScanExec.isSchemaSupported(partitionSchema)
               && COMET_FULL_NATIVE_SCAN_ENABLED.get =>
-          logInfo("Comet extension enabled for v1 Scan")
-          CometNativeScanExec(scanExec, session)
+          logInfo("Comet extension enabled for v1 full native Scan")
+          CometScanExec(scanExec, session)
+
         // data source V1
         case scanExec @ FileSourceScanExec(
             HadoopFsRelation(_, partitionSchema, _, _, _: ParquetFileFormat, _),
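Note the phase change in this rule: when COMET_FULL_NATIVE_SCAN_ENABLED is set and both schemas are supported, the V1 Parquet FileSourceScanExec is now planned as a regular CometScanExec rather than being turned into CometNativeScanExec on the spot; the fully native operator is substituted later, in the transformUp pass shown in the next hunk (see the sketch after it).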
@@ -365,6 +366,12 @@ class CometSparkSessionExtensions
     }

     plan.transformUp {
+      // Fully native scan for V1
+      case scan: CometScanExec if COMET_FULL_NATIVE_SCAN_ENABLED.get =>
+        val nativeOp = QueryPlanSerde.operator2Proto(scan).get
+        CometNativeScanExec(nativeOp, scan.wrapped, scan.session)
+
+      // Comet JVM + native scan for V1 and V2
       case op if isCometScan(op) =>
         val nativeOp = QueryPlanSerde.operator2Proto(op).get
         CometScanWrapper(nativeOp, op)
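This new transformUp case is where the fully native V1 scan actually materializes: the JVM-side CometScanExec planned earlier is serialized via operator2Proto and replaced by CometNativeScanExec, while the other Comet scans keep the existing CometScanWrapper path. Because transformUp visits children before parents, the scan at the leaf is swapped out before any parent operator is processed. A toy Scala sketch of that bottom-up rewrite, with simplified stand-in classes rather than Spark's API:

    // Minimal stand-in for a plan tree and a bottom-up rewrite rule.
    sealed trait Node { def children: Seq[Node] }
    case class Scan(name: String) extends Node { def children: Seq[Node] = Nil }
    case class Project(child: Node) extends Node { def children: Seq[Node] = Seq(child) }

    def transformUp(n: Node)(rule: PartialFunction[Node, Node]): Node = {
      // Rewrite children first, then offer the rule the node with rewritten children.
      val withNewChildren = n match {
        case p: Project => p.copy(child = transformUp(p.child)(rule))
        case leaf => leaf
      }
      if (rule.isDefinedAt(withNewChildren)) rule(withNewChildren) else withNewChildren
    }

    // The leaf scan is rewritten first; the parent keeps the rewritten child.
    val plan = Project(Scan("cometScan"))
    val rewritten = transformUp(plan) { case Scan(n) => Scan(s"native-$n") }
    assert(rewritten == Project(Scan("native-cometScan")))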
@@ -1221,8 +1228,7 @@ object CometSparkSessionExtensions extends Logging {
   }

   def isCometScan(op: SparkPlan): Boolean = {
-    op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec] ||
-      op.isInstanceOf[CometNativeScanExec]
+    op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
   }

   private def shouldApplySparkToColumnar(conf: SQLConf, op: SparkPlan): Boolean = {
spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, Normalize
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometNativeScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, CometNativeScanExec, CometScanExec, CometSinkPlaceHolder, CometSparkToColumnarExec, DecimalPrecision}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
@@ -2481,7 +2481,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
         childOp.foreach(result.addChildren)

         op match {
-          case scan: CometNativeScanExec =>
+
+          // Fully native scan for V1
+          case scan: CometScanExec if CometConf.COMET_FULL_NATIVE_SCAN_ENABLED.get =>
             val nativeScanBuilder = OperatorOuterClass.NativeScan.newBuilder()
             nativeScanBuilder.setSource(op.simpleStringWithNodeId())
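Taken together with the planner.rs change at the top of this commit, this is the hand-off point: when COMET_FULL_NATIVE_SCAN_ENABLED is set, operator2Proto serializes a CometScanExec into a NativeScan protobuf message (tagged here with its human-readable source string), and the native PhysicalPlanner consumes that message to build DataFusion's FileScanConfig, including the projection vector shown in the first hunk.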
56 changes: 56 additions & 0 deletions spark/src/main/scala/org/apache/spark/sql/comet/CometExecRDD.scala
@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.spark.sql.comet

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.vectorized.ColumnarBatch

/**
 * An RDD that executes a Spark SQL query with Comet native execution to generate ColumnarBatches.
*/
private[spark] class CometExecRDD(
sc: SparkContext,
partitionNum: Int,
var f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch])
extends RDD[ColumnarBatch](sc, Nil) {

override def compute(s: Partition, context: TaskContext): Iterator[ColumnarBatch] = {
f(Seq.empty)
}

override protected def getPartitions: Array[Partition] = {
Array.tabulate(partitionNum)(i =>
new Partition {
override def index: Int = i
})
}
}

object CometExecRDD {
def apply(sc: SparkContext, partitionNum: Int)(
f: (Seq[Iterator[ColumnarBatch]]) => Iterator[ColumnarBatch]): RDD[ColumnarBatch] =
withScope(sc) {
new CometExecRDD(sc, partitionNum, f)
}

private[spark] def withScope[U](sc: SparkContext)(body: => U): U =
RDDOperationScope.withScope[U](sc)(body)
}
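CometExecRDD gives a source operator with no Spark-side input, such as the fully native scan, a way to fan work out over a fixed number of partitions: compute ignores its partition's input entirely and invokes the supplied closure with an empty Seq. A hypothetical usage sketch; the helper name and closure body are assumptions, not code from this commit:

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.comet.CometExecRDD
    import org.apache.spark.sql.vectorized.ColumnarBatch

    // Hypothetical helper: build an RDD whose partitions each run some native work.
    def nativeScanRDD(sc: SparkContext, numParts: Int): RDD[ColumnarBatch] =
      CometExecRDD(sc, numParts) { (_: Seq[Iterator[ColumnarBatch]]) =>
        // A real caller would launch the serialized native plan for this
        // partition here and return an iterator over its output batches.
        Iterator.empty
      }

The closure receives Seq.empty because a scan has no upstream Comet iterators; operators with children would instead be fed their input batch iterators.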