[VL] Code cleanup for Arrow CSV UTs (#7651)
zhztheplayer authored Oct 24, 2024
1 parent 2937c66 commit b1741c6
Showing 5 changed files with 231 additions and 254 deletions.
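
This cleanup moves the batchType0() override that ArrowBatchScanExec and ArrowFileSourceScanExec previously duplicated into a new BaseArrowScanExec trait, and consolidates the Arrow CSV scan unit tests into an ArrowCsvScanSuite hierarchy with DataSource V1 and V2 variants.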

File: ArrowBatchScanExec.scala
@@ -16,27 +16,20 @@
  */
 package org.apache.gluten.execution.datasource.v2
 
-import org.apache.gluten.columnarbatch.ArrowBatches
-import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.transition.Convention
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.connector.read.{Batch, PartitionReaderFactory, Scan}
+import org.apache.spark.sql.execution.BaseArrowScanExec
 import org.apache.spark.sql.execution.datasources.v2.{ArrowBatchScanExecShim, BatchScanExec}
 
 case class ArrowBatchScanExec(original: BatchScanExec)
   extends ArrowBatchScanExecShim(original)
-  with GlutenPlan {
+  with BaseArrowScanExec {
 
   @transient lazy val batch: Batch = original.batch
 
-  override protected def batchType0(): Convention.BatchType = {
-    ArrowBatches.ArrowJavaBatch
-  }
-
   override lazy val readerFactory: PartitionReaderFactory = original.readerFactory
 
   override lazy val inputRDD: RDD[InternalRow] = original.inputRDD

File: ArrowFileSourceScanExec.scala
@@ -16,10 +16,6 @@
  */
 package org.apache.spark.sql.execution
 
-import org.apache.gluten.columnarbatch.ArrowBatches
-import org.apache.gluten.extension.GlutenPlan
-import org.apache.gluten.extension.columnar.transition.Convention
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -29,7 +25,7 @@ import scala.concurrent.duration.NANOSECONDS
 
 case class ArrowFileSourceScanExec(original: FileSourceScanExec)
   extends ArrowFileSourceScanLikeShim(original)
-  with GlutenPlan {
+  with BaseArrowScanExec {
 
   lazy val inputRDD: RDD[InternalRow] = original.inputRDD
 
@@ -41,10 +37,6 @@ case class ArrowFileSourceScanExec(original: FileSourceScanExec)
 
   override def doCanonicalize(): FileSourceScanExec = original.doCanonicalize()
 
-  override protected def batchType0(): Convention.BatchType = {
-    ArrowBatches.ArrowJavaBatch
-  }
-
   override protected def doExecuteColumnar(): RDD[ColumnarBatch] = {
     val numOutputRows = longMetric("numOutputRows")
     val scanTime = longMetric("scanTime")

File: BaseArrowScanExec.scala (new file)
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution

import org.apache.gluten.columnarbatch.ArrowBatches
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.extension.columnar.transition.Convention

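/**
 * Mixed into Arrow-backed scan operators to pin their batch type to Arrow Java batches,
 * replacing the identical batchType0() overrides the operators used to carry.
 */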
trait BaseArrowScanExec extends GlutenPlan {

  final override protected def batchType0(): Convention.BatchType = {
    ArrowBatches.ArrowJavaBatch
  }
}
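
Both ArrowBatchScanExec and ArrowFileSourceScanExec now mix in this trait in place of their own identical batchType0() overrides, and the final modifier keeps the batch type from drifting between Arrow scan operators. A minimal sketch of the mix-in pattern (MyArrowScanExec is a hypothetical name, not a class in this commit):

// Hypothetical operator illustrating the pattern used by the two scan
// execs changed above; BaseArrowScanExec supplies batchType0().
case class MyArrowScanExec(original: BatchScanExec)
  extends ArrowBatchScanExecShim(original)
  with BaseArrowScanExec {
  // batchType0() is inherited as final; no per-operator override is needed.
}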

File: ArrowCsvScanSuite.scala (new file)
@@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.execution

import org.apache.gluten.GlutenConfig
import org.apache.gluten.datasource.ArrowCSVFileFormat

import org.apache.spark.SparkConf
import org.apache.spark.sql.execution.{ArrowFileSourceScanExec, BaseArrowScanExec, ColumnarToRowExec}
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

class ArrowCsvScanSuiteV1 extends ArrowCsvScanSuite {
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.sql.sources.useV1SourceList", "csv")
  }

  test("csv scan v1") {
    val df = runAndCompare("select * from student")()
    val plan = df.queryExecution.executedPlan
    assert(plan.find(s => s.isInstanceOf[ColumnarToRowExec]).isDefined)
    assert(plan.find(_.isInstanceOf[BaseArrowScanExec]).isDefined)
    val scan = plan.find(_.isInstanceOf[BaseArrowScanExec]).toList.head
    assert(
      scan
        .asInstanceOf[ArrowFileSourceScanExec]
        .relation
        .fileFormat
        .isInstanceOf[ArrowCSVFileFormat])
  }

  test("csv scan with schema v1") {
    val df = runAndCompare("select * from student_option_schema")()
    val plan = df.queryExecution.executedPlan
    assert(plan.find(s => s.isInstanceOf[ColumnarToRowExec]).isDefined)
    val scan = plan.find(_.isInstanceOf[BaseArrowScanExec])
    assert(scan.isDefined)
    assert(
      !scan.get
        .asInstanceOf[ArrowFileSourceScanExec]
        .original
        .relation
        .fileFormat
        .asInstanceOf[ArrowCSVFileFormat]
        .fallback)
  }
}

class ArrowCsvScanSuiteV2 extends ArrowCsvScanSuite {
  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.sql.sources.useV1SourceList", "")
  }
}

/** Since https://github.com/apache/incubator-gluten/pull/5850. */
abstract class ArrowCsvScanSuite extends VeloxWholeStageTransformerSuite {
  override protected val resourcePath: String = "N/A"
  override protected val fileFormat: String = "N/A"

  protected val rootPath: String = getClass.getResource("/").getPath

  override def beforeAll(): Unit = {
    super.beforeAll()
    createCsvTables()
  }

  override def afterAll(): Unit = {
    super.afterAll()
  }

  override protected def sparkConf: SparkConf = {
    super.sparkConf
      .set("spark.shuffle.manager", "org.apache.spark.shuffle.sort.ColumnarShuffleManager")
      .set("spark.sql.files.maxPartitionBytes", "1g")
      .set("spark.sql.shuffle.partitions", "1")
      .set("spark.memory.offHeap.size", "2g")
      .set("spark.unsafe.exceptionOnMemoryLeak", "true")
      .set("spark.sql.autoBroadcastJoinThreshold", "-1")
      .set(GlutenConfig.NATIVE_ARROW_READER_ENABLED.key, "true")
  }

test("csv scan with option string as null") {
val df = runAndCompare("select * from student_option_str")()
val plan = df.queryExecution.executedPlan
assert(plan.find(_.isInstanceOf[ColumnarToRowExec]).isDefined)
assert(plan.find(_.isInstanceOf[BaseArrowScanExec]).isDefined)
}

test("csv scan with option delimiter") {
val df = runAndCompare("select * from student_option")()
val plan = df.queryExecution.executedPlan
assert(plan.find(s => s.isInstanceOf[ColumnarToRowExec]).isDefined)
assert(plan.find(_.isInstanceOf[BaseArrowScanExec]).isDefined)
}

test("csv scan with missing columns") {
val df =
runAndCompare("select languagemissing, language, id_new_col from student_option_schema_lm")()
val plan = df.queryExecution.executedPlan
assert(plan.find(s => s.isInstanceOf[VeloxColumnarToRowExec]).isDefined)
assert(plan.find(_.isInstanceOf[BaseArrowScanExec]).isDefined)
}

test("csv scan with different name") {
val df = runAndCompare("select * from student_option_schema")()
val plan = df.queryExecution.executedPlan
assert(plan.find(s => s.isInstanceOf[ColumnarToRowExec]).isDefined)
assert(plan.find(_.isInstanceOf[BaseArrowScanExec]).isDefined)

val df2 = runAndCompare("select * from student_option_schema")()
val plan2 = df2.queryExecution.executedPlan
assert(plan2.find(s => s.isInstanceOf[ColumnarToRowExec]).isDefined)
assert(plan2.find(_.isInstanceOf[BaseArrowScanExec]).isDefined)
}

test("csv scan with filter") {
val df = runAndCompare("select * from student where Name = 'Peter'")()
assert(df.queryExecution.executedPlan.find(s => s.isInstanceOf[ColumnarToRowExec]).isEmpty)
assert(
df.queryExecution.executedPlan
.find(s => s.isInstanceOf[BaseArrowScanExec])
.isDefined)
}

test("insert into select from csv") {
withTable("insert_csv_t") {
spark.sql("create table insert_csv_t(Name string, Language string) using parquet;")
runQueryAndCompare("""
|insert into insert_csv_t select * from student;
|""".stripMargin) {
checkGlutenOperatorMatch[BaseArrowScanExec]
}
}
}

test("count(1) on csv scan") {
val df = runAndCompare("select count(1) from student")()
checkLengthAndPlan(df, 1)
}

  private def createCsvTables(): Unit = {
    spark.read
      .format("csv")
      .option("header", "true")
      .load(rootPath + "/datasource/csv/student.csv")
      .createOrReplaceTempView("student")

    spark.read
      .format("csv")
      .option("header", "true")
      .load(rootPath + "/datasource/csv/student_option_str.csv")
      .createOrReplaceTempView("student_option_str")

    spark.read
      .format("csv")
      .option("header", "true")
      .option("delimiter", ";")
      .load(rootPath + "/datasource/csv/student_option.csv")
      .createOrReplaceTempView("student_option")

    spark.read
      .schema(
        new StructType()
          .add("id", StringType)
          .add("name", StringType)
          .add("language", StringType))
      .format("csv")
      .option("header", "true")
      .load(rootPath + "/datasource/csv/student_option_schema.csv")
      .createOrReplaceTempView("student_option_schema")

    spark.read
      .schema(
        new StructType()
          .add("id_new_col", IntegerType)
          .add("name", StringType)
          .add("language", StringType)
          .add("languagemissing", StringType))
      .format("csv")
      .option("header", "true")
      .load(rootPath + "/datasource/csv/student_option_schema.csv")
      .createOrReplaceTempView("student_option_schema_lm")
  }
}
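
For local verification, an invocation along these lines should work, assuming Gluten's scalatest-maven-plugin wiring and a pre-built Velox backend (the profile and property names are illustrative and may differ by branch):

mvn test -Pbackends-velox -Pspark-3.4 -DwildcardSuites=org.apache.gluten.execution.ArrowCsvScanSuiteV1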