diff --git a/.github/workflows/build_bundle_package.yml b/.github/workflows/build_bundle_package.yml index 01ddd6f43857..8ce659366770 100644 --- a/.github/workflows/build_bundle_package.yml +++ b/.github/workflows/build_bundle_package.yml @@ -38,7 +38,7 @@ on: jobs: build-native-lib: runs-on: ubuntu-20.04 - container: apache/gluten:gluten-vcpkg-builder_2024_03_17 + container: apache/gluten:gluten-vcpkg-builder_2024_05_29 steps: - uses: actions/checkout@v2 - name: Build Gluten velox third party @@ -53,11 +53,17 @@ jobs: export NUM_THREADS=4 ./dev/builddeps-veloxbe.sh --build_tests=OFF --build_benchmarks=OFF --enable_s3=OFF \ --enable_gcs=OFF --enable_hdfs=ON --enable_abfs=OFF - - uses: actions/upload-artifact@v2 + - name: Upload native libs + uses: actions/upload-artifact@v2 with: path: ./cpp/build/releases/ name: velox-native-lib-${{github.sha}} retention-days: 1 + - name: Upload Artifact Arrow Jar + uses: actions/upload-artifact@v2 + with: + path: /root/.m2/repository/org/apache/arrow/ + name: velox-arrow-jar-centos-7-${{github.sha}} build-bundle-package-ubuntu: if: startsWith(github.event.inputs.os, 'ubuntu') @@ -71,6 +77,11 @@ jobs: with: name: velox-native-lib-${{github.sha}} path: ./cpp/build/releases + - name: Download All Arrow Jar Artifacts + uses: actions/download-artifact@v2 + with: + name: velox-arrow-jar-centos-7-${{github.sha}} + path: /root/.m2/repository/org/apache/arrow/ - name: Setup java and maven run: | apt-get update && \ @@ -99,6 +110,11 @@ jobs: with: name: velox-native-lib-${{github.sha}} path: ./cpp/build/releases + - name: Download All Arrow Jar Artifacts + uses: actions/download-artifact@v2 + with: + name: velox-arrow-jar-centos-7-${{github.sha}} + path: /root/.m2/repository/org/apache/arrow/ - name: Setup java and maven run: | yum update -y && yum install -y java-1.8.0-openjdk-devel wget && \ @@ -130,6 +146,11 @@ jobs: with: name: velox-native-lib-${{github.sha}} path: ./cpp/build/releases + - name: Download All Arrow Jar Artifacts + uses: actions/download-artifact@v2 + with: + name: velox-arrow-jar-centos-7-${{github.sha}} + path: /root/.m2/repository/org/apache/arrow/ - name: Update mirror list run: | sed -i -e "s|mirrorlist=|#mirrorlist=|g" /etc/yum.repos.d/CentOS-* || true && \ diff --git a/.github/workflows/velox_docker.yml b/.github/workflows/velox_docker.yml index 5f64c9f7e0e8..d110d0a6d223 100644 --- a/.github/workflows/velox_docker.yml +++ b/.github/workflows/velox_docker.yml @@ -120,6 +120,12 @@ jobs: with: name: velox-arrow-jar-centos-7-${{github.sha}} path: /root/.m2/repository/org/apache/arrow/ + - name: Setup tzdata + run: | + if [ "${{ matrix.os }}" = "ubuntu:22.04" ]; then + apt-get update + TZ="Etc/GMT" DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata + fi - name: Setup java and maven run: | if [ "${{ matrix.java }}" = "java-17" ]; then @@ -515,7 +521,7 @@ jobs: fail-fast: false matrix: spark: ["spark-3.2"] - celeborn: ["celeborn-0.4.0", "celeborn-0.3.2"] + celeborn: ["celeborn-0.4.1", "celeborn-0.3.2-incubating"] runs-on: ubuntu-20.04 container: ubuntu:22.04 steps: @@ -530,6 +536,10 @@ jobs: with: name: velox-arrow-jar-centos-7-${{github.sha}} path: /root/.m2/repository/org/apache/arrow/ + - name: Setup tzdata + run: | + apt-get update + TZ="Etc/GMT" DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata - name: Setup java and maven run: | apt-get update && apt-get install -y openjdk-8-jdk maven wget @@ -547,8 +557,8 @@ jobs: fi echo "EXTRA_PROFILE: ${EXTRA_PROFILE}" cd /opt && mkdir -p celeborn && \ - wget 
https://archive.apache.org/dist/incubator/celeborn/${{ matrix.celeborn }}-incubating/apache-${{ matrix.celeborn }}-incubating-bin.tgz && \ - tar xzf apache-${{ matrix.celeborn }}-incubating-bin.tgz -C /opt/celeborn --strip-components=1 && cd celeborn && \ + wget https://archive.apache.org/dist/celeborn/${{ matrix.celeborn }}/apache-${{ matrix.celeborn }}-bin.tgz && \ + tar xzf apache-${{ matrix.celeborn }}-bin.tgz -C /opt/celeborn --strip-components=1 && cd celeborn && \ mv ./conf/celeborn-env.sh.template ./conf/celeborn-env.sh && \ bash -c "echo -e 'CELEBORN_MASTER_MEMORY=4g\nCELEBORN_WORKER_MEMORY=4g\nCELEBORN_WORKER_OFFHEAP_MEMORY=8g' > ./conf/celeborn-env.sh" && \ bash -c "echo -e 'celeborn.worker.commitFiles.threads 128\nceleborn.worker.sortPartition.threads 64' > ./conf/celeborn-defaults.conf" && \ diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index 941237629569..376e46ebe975 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.backendsapi.clickhouse -import org.apache.gluten.{GlutenConfig, GlutenNumaBindingInfo} +import org.apache.gluten.GlutenNumaBindingInfo import org.apache.gluten.backendsapi.IteratorApi import org.apache.gluten.execution._ import org.apache.gluten.expression.ConverterUtils @@ -61,6 +61,52 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { StructType(dataSchema) } + private def createNativeIterator( + splitInfoByteArray: Array[Array[Byte]], + wsPlan: Array[Byte], + materializeInput: Boolean, + inputIterators: Seq[Iterator[ColumnarBatch]]): BatchIterator = { + + /** Generate closeable ColumnBatch iterator. 
*/ + val listIterator = + inputIterators + .map { + case i: CloseableCHColumnBatchIterator => i + case it => new CloseableCHColumnBatchIterator(it) + } + .map(it => new ColumnarNativeIterator(it.asJava).asInstanceOf[GeneralInIterator]) + .asJava + new CHNativeExpressionEvaluator().createKernelWithBatchIterator( + wsPlan, + splitInfoByteArray, + listIterator, + materializeInput + ) + } + + private def createCloseIterator( + context: TaskContext, + pipelineTime: SQLMetric, + updateNativeMetrics: IMetrics => Unit, + updateInputMetrics: Option[InputMetricsWrapper => Unit] = None, + nativeIter: BatchIterator): CloseableCHColumnBatchIterator = { + + val iter = new CollectMetricIterator( + nativeIter, + updateNativeMetrics, + updateInputMetrics, + updateInputMetrics.map(_ => context.taskMetrics().inputMetrics).orNull) + + context.addTaskFailureListener( + (ctx, _) => { + if (ctx.isInterrupted()) { + iter.cancel() + } + }) + context.addTaskCompletionListener[Unit](_ => iter.close()) + new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) + } + // only set file schema for text format table private def setFileSchemaForLocalFiles( localFilesNode: LocalFilesNode, @@ -198,45 +244,24 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { inputIterators: Seq[Iterator[ColumnarBatch]] = Seq() ): Iterator[ColumnarBatch] = { - assert( + require( inputPartition.isInstanceOf[GlutenPartition], "CH backend only accepts GlutenPartition in GlutenWholeStageColumnarRDD.") - - val transKernel = new CHNativeExpressionEvaluator() - val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map { - iter => new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava) - }.asJava) - val splitInfoByteArray = inputPartition .asInstanceOf[GlutenPartition] .splitInfosByteArray - val nativeIter = - transKernel.createKernelWithBatchIterator( - inputPartition.plan, - splitInfoByteArray, - inBatchIters, - false) + val wsPlan = inputPartition.plan + val materializeInput = false - val iter = new CollectMetricIterator( - nativeIter, - updateNativeMetrics, - updateInputMetrics, - context.taskMetrics().inputMetrics) - - context.addTaskFailureListener( - (ctx, _) => { - if (ctx.isInterrupted()) { - iter.cancel() - } - }) - context.addTaskCompletionListener[Unit](_ => iter.close()) - - // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator( context, - new CloseableCHColumnBatchIterator( - iter.asInstanceOf[Iterator[ColumnarBatch]], - Some(pipelineTime))) + createCloseIterator( + context, + pipelineTime, + updateNativeMetrics, + Some(updateInputMetrics), + createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) + ) } // Generate Iterator[ColumnarBatch] for final stage. @@ -252,52 +277,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { partitionIndex: Int, materializeInput: Boolean): Iterator[ColumnarBatch] = { // scalastyle:on argcount - GlutenConfig.getConf - - val transKernel = new CHNativeExpressionEvaluator() - val columnarNativeIterator = - new JArrayList[GeneralInIterator](inputIterators.map { - iter => - new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava) - }.asJava) - // we need to complete dependency RDD's firstly - val nativeIterator = transKernel.createKernelWithBatchIterator( - rootNode.toProtobuf.toByteArray, - // Final iterator does not contain scan split, so pass empty split info to native here. 
- new Array[Array[Byte]](0), - columnarNativeIterator, - materializeInput - ) - - val iter = new CollectMetricIterator(nativeIterator, updateNativeMetrics, null, null) - context.addTaskFailureListener( - (ctx, _) => { - if (ctx.isInterrupted()) { - iter.cancel() - } - }) - context.addTaskCompletionListener[Unit](_ => iter.close()) - new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) - } -} + // Final iterator does not contain scan split, so pass empty split info to native here. + val splitInfoByteArray = new Array[Array[Byte]](0) + val wsPlan = rootNode.toProtobuf.toByteArray -object CHIteratorApi { - - /** Generate closeable ColumnBatch iterator. */ - def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { - iter match { - case _: CloseableCHColumnBatchIterator => iter - case _ => new CloseableCHColumnBatchIterator(iter) - } + // we need to complete dependency RDD's firstly + createCloseIterator( + context, + pipelineTime, + updateNativeMetrics, + None, + createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) } } class CollectMetricIterator( val nativeIterator: BatchIterator, val updateNativeMetrics: IMetrics => Unit, - val updateInputMetrics: InputMetricsWrapper => Unit, - val inputMetrics: InputMetrics + val updateInputMetrics: Option[InputMetricsWrapper => Unit] = None, + val inputMetrics: InputMetrics = null ) extends Iterator[ColumnarBatch] { private var outputRowCount = 0L private var outputVectorCount = 0L @@ -329,9 +328,7 @@ class CollectMetricIterator( val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics] nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount) updateNativeMetrics(nativeMetrics) - if (updateInputMetrics != null) { - updateInputMetrics(inputMetrics) - } + updateInputMetrics.foreach(_(inputMetrics)) metricsUpdated = true } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 1c83e326eed4..ac3ea61ff810 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -50,7 +50,6 @@ import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation} -import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} @@ -583,14 +582,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = List() - /** - * Generate extended columnar post-rules. 
- * - * @return - */ - override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => NativeWritePostRule(spark)) - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index a7e7769e7736..da9d9c7586c0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi import org.apache.gluten.extension.ValidationResult import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} @@ -75,7 +74,7 @@ case class CHBroadcastBuildSideRDD( override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = { CHBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted, broadcastContext) - CHIteratorApi.genCloseableColumnBatchIterator(Iterator.empty) + Iterator.empty } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala index cf45c1118f13..e9bee84396f8 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/utils/CHExpressionUtil.scala @@ -209,7 +209,6 @@ object CHExpressionUtil { UNIX_MICROS -> DefaultValidator(), TIMESTAMP_MILLIS -> DefaultValidator(), TIMESTAMP_MICROS -> DefaultValidator(), - FLATTEN -> DefaultValidator(), STACK -> DefaultValidator() ) } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala index 088487101081..7320b7c05152 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseDecimalSuite.scala @@ -67,9 +67,9 @@ class GlutenClickHouseDecimalSuite private val decimalTPCHTables: Seq[(DecimalType, Seq[Int])] = Seq.apply( (DecimalType.apply(9, 4), Seq()), // 1: ch decimal avg is float - (DecimalType.apply(18, 8), Seq(1)), + (DecimalType.apply(18, 8), Seq()), // 1: ch decimal avg is float, 3/10: all value is null and compare with limit - (DecimalType.apply(38, 19), Seq(1, 3, 10)) + (DecimalType.apply(38, 19), Seq(3, 10)) ) private def createDecimalTables(dataType: DecimalType): Unit = { @@ -337,7 +337,6 @@ class GlutenClickHouseDecimalSuite allowPrecisionLoss => Range .inclusive(1, 22) - .filter(_ != 17) // Ignore Q17 which include avg .foreach { sql_num => { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala index 572d0cd50a6e..99b212059966 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnHDFSSuite.scala @@ -25,10 
+25,12 @@ import org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMerg import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.{FileSystem, Path} import java.io.File +import scala.concurrent.duration.DurationInt + // Some sqls' line length exceeds 100 // scalastyle:off line.size.limit @@ -614,5 +616,45 @@ class GlutenClickHouseMergeTreeWriteOnHDFSSuite .count() assertResult(600572)(result) } + + test("test mergetree insert with optimize basic") { + val tableName = "lineitem_mergetree_insert_optimize_basic_hdfs" + val dataPath = s"$HDFS_URL/test/$tableName" + + withSQLConf( + "spark.databricks.delta.optimize.minFileSize" -> "200000000", + "spark.gluten.sql.columnar.backend.ch.runtime_settings.mergetree.merge_after_insert" -> "true", + "spark.gluten.sql.columnar.backend.ch.runtime_settings.mergetree.insert_without_local_storage" -> "true", + "spark.gluten.sql.columnar.backend.ch.runtime_settings.min_insert_block_size_rows" -> "10000" + ) { + spark.sql(s""" + |DROP TABLE IF EXISTS $tableName; + |""".stripMargin) + + spark.sql(s""" + |CREATE TABLE IF NOT EXISTS $tableName + |USING clickhouse + |LOCATION '$dataPath' + |TBLPROPERTIES (storage_policy='__hdfs_main') + | as select * from lineitem + |""".stripMargin) + + val ret = spark.sql(s"select count(*) from $tableName").collect() + assertResult(600572)(ret.apply(0).get(0)) + val conf = new Configuration + conf.set("fs.defaultFS", HDFS_URL) + val fs = FileSystem.get(conf) + + eventually(timeout(60.seconds), interval(2.seconds)) { + val it = fs.listFiles(new Path(dataPath), true) + var files = 0 + while (it.hasNext) { + it.next() + files += 1 + } + assertResult(72)(files) + } + } + } } // scalastyle:off line.size.limit diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 9269303d9251..ccf7bb5d5b2a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -21,6 +21,7 @@ import org.apache.gluten.execution.AllDataTypesWithComplexType.genTestData import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf +import org.apache.spark.gluten.NativeWriteChecker import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -28,11 +29,14 @@ import org.apache.spark.sql.test.SharedSparkSession import org.scalatest.BeforeAndAfterAll +import scala.reflect.runtime.universe.TypeTag + class GlutenClickHouseNativeWriteTableSuite extends GlutenClickHouseWholeStageTransformerSuite with AdaptiveSparkPlanHelper with SharedSparkSession - with BeforeAndAfterAll { + with BeforeAndAfterAll + with NativeWriteChecker { private var _hiveSpark: SparkSession = _ @@ -114,16 +118,19 @@ class GlutenClickHouseNativeWriteTableSuite def getColumnName(s: String): String = { s.replaceAll("\\(", "_").replaceAll("\\)", "_") } + import collection.immutable.ListMap import java.io.File def writeIntoNewTableWithSql(table_name: String, table_create_sql: String)( fields: Seq[String]): Unit = { - spark.sql(table_create_sql) - spark.sql( - s"insert overwrite $table_name select ${fields.mkString(",")}" + - s" 
from origin_table") + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite( + s"insert overwrite $table_name select ${fields.mkString(",")}" + + s" from origin_table", + checkNative = true) + } } def writeAndCheckRead( @@ -170,82 +177,86 @@ class GlutenClickHouseNativeWriteTableSuite }) } - test("test insert into dir") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") + private val fields_ = ListMap( + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("byte_field", "byte"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)"), + ("date_field", "date") + ) - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) + def withDestinationTable(table: String, createTableSql: String)(f: => Unit): Unit = { + spark.sql(s"drop table IF EXISTS $table") + spark.sql(s"$createTableSql") + f + } - for (format <- formats) { - spark.sql( - s"insert overwrite local directory '$basePath/test_insert_into_${format}_dir1' " - + s"stored as $format select " - + fields.keys.mkString(",") + - " from origin_table cluster by (byte_field)") - spark.sql( - s"insert overwrite local directory '$basePath/test_insert_into_${format}_dir2' " + - s"stored as $format " + - "select string_field, sum(int_field) as x from origin_table group by string_field") - } + def nativeWrite(f: String => Unit): Unit = { + withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { + formats.foreach(f(_)) } } - test("test insert into partition") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - ("spark.sql.orc.compression.codec", "lz4"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") + def nativeWrite2( + f: String => (String, String, String), + extraCheck: (String, String, String) => Unit = null): Unit = nativeWrite { + format => + val (table_name, table_create_sql, insert_sql) = f(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) + Option(extraCheck).foreach(_(table_name, table_create_sql, insert_sql)) + } + } - val table_create_sql = - s"create table if not exists $table_name (" + - fields - .map(f => s"${f._1} ${f._2}") - .mkString(",") + - " ) partitioned by (another_date_field date) " + - s"stored as $format" + def nativeWriteWithOriginalView[A <: Product: TypeTag]( + data: Seq[A], + viewName: String, + pairs: (String, String)*)(f: String => Unit): Unit = { + val configs = 
pairs :+ ("spark.gluten.sql.native.writer.enabled", "true") + withSQLConf(configs: _*) { + withTempView(viewName) { + spark.createDataFrame(data).createOrReplaceTempView(viewName) + formats.foreach(f(_)) + } + } + } - spark.sql(table_create_sql) + test("test insert into dir") { + nativeWriteWithOriginalView(genTestData(), "origin_table") { + format => + Seq( + s"""insert overwrite local directory '$basePath/test_insert_into_${format}_dir1' + |stored as $format select ${fields_.keys.mkString(",")} + |from origin_table""".stripMargin, + s"""insert overwrite local directory '$basePath/test_insert_into_${format}_dir2' + |stored as $format select string_field, sum(int_field) as x + |from origin_table group by string_field""".stripMargin + ).foreach(checkNativeWrite(_, checkNative = true)) + } + } - spark.sql( - s"insert into $table_name partition(another_date_field = '2020-01-01') select " - + fields.keys.mkString(",") + - " from origin_table") + test("test insert into partition") { + def destination(format: String): (String, String, String) = { + val table_name = table_name_template.format(format) + val table_create_sql = + s"""create table if not exists $table_name + |(${fields_.map(f => s"${f._1} ${f._2}").mkString(",")}) + |partitioned by (another_date_field date) stored as $format""".stripMargin + val insert_sql = + s"""insert into $table_name partition(another_date_field = '2020-01-01') + | select ${fields_.keys.mkString(",")} from origin_table""".stripMargin + (table_name, table_create_sql, insert_sql) + } + def nativeFormatWrite(format: String): Unit = { + val (table_name, table_create_sql, insert_sql) = destination(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) var files = recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) .filter(_.getName.endsWith(s".$format")) if (format == "orc") { @@ -255,154 +266,103 @@ class GlutenClickHouseNativeWriteTableSuite assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01")) } } + + nativeWriteWithOriginalView( + genTestData(), + "origin_table", + ("spark.sql.orc.compression.codec", "lz4"))(nativeFormatWrite) } test("test CTAS") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { + nativeWriteWithOriginalView(genTestData(), "origin_table") { + format => val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") val table_create_sql = s"create table $table_name using $format as select " + - fields + fields_ .map(f => s"${f._1}") .mkString(",") + " from origin_table" - spark.sql(table_create_sql) - spark.sql(s"drop table IF EXISTS $table_name") + val insert_sql = + s"create table $table_name as select " + + fields_ + .map(f => s"${f._1}") + .mkString(",") + + " from origin_table" + withDestinationTable(table_name, table_create_sql) { + spark.sql(s"drop table IF EXISTS $table_name") - try { - val table_create_sql = - s"create table $table_name as select " + - fields - .map(f => 
s"${f._1}") - .mkString(",") + - " from origin_table" - spark.sql(table_create_sql) - } catch { - case _: UnsupportedOperationException => // expected - case _: Exception => fail("should not throw exception") + try { + // FIXME: using checkNativeWrite + spark.sql(insert_sql) + } catch { + case _: UnsupportedOperationException => // expected + case e: Exception => fail("should not throw exception", e) + } } - } } } test("test insert into partition, bigo's case which incur InsertIntoHiveTable") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - ("spark.sql.hive.convertMetastoreParquet", "false"), - ("spark.sql.hive.convertMetastoreOrc", "false"), - (GlutenConfig.GLUTEN_ENABLED.key, "true") - ) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") - val table_create_sql = s"create table if not exists $table_name (" + fields - .map(f => s"${f._1} ${f._2}") - .mkString(",") + " ) partitioned by (another_date_field string)" + - s"stored as $format" + def destination(format: String): (String, String, String) = { + val table_name = table_name_template.format(format) + val table_create_sql = s"create table if not exists $table_name (" + fields_ + .map(f => s"${f._1} ${f._2}") + .mkString(",") + " ) partitioned by (another_date_field string)" + + s"stored as $format" + val insert_sql = + s"insert overwrite table $table_name " + + "partition(another_date_field = '2020-01-01') select " + + fields_.keys.mkString(",") + " from (select " + fields_.keys.mkString( + ",") + ", row_number() over (order by int_field desc) as rn " + + "from origin_table where float_field > 3 ) tt where rn <= 100" + (table_name, table_create_sql, insert_sql) + } - spark.sql(table_create_sql) - spark.sql( - s"insert overwrite table $table_name " + - "partition(another_date_field = '2020-01-01') select " - + fields.keys.mkString(",") + " from (select " + fields.keys.mkString( - ",") + ", row_number() over (order by int_field desc) as rn " + - "from origin_table where float_field > 3 ) tt where rn <= 100") + def nativeFormatWrite(format: String): Unit = { + val (table_name, table_create_sql, insert_sql) = destination(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) val files = recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) .filter(_.getName.startsWith("part")) assert(files.length == 1) assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01")) } } + + nativeWriteWithOriginalView( + genTestData(), + "origin_table", + ("spark.sql.hive.convertMetastoreParquet", "false"), + ("spark.sql.hive.convertMetastoreOrc", "false"))(nativeFormatWrite) } test("test 1-col partitioned table") { + nativeWrite { + format => + { + val table_name = table_name_template.format(format) + val table_create_sql = + s"create table if not exists $table_name (" + + fields_ + .filterNot(e => e._1.equals("date_field")) + .map(f => s"${f._1} ${f._2}") + .mkString(",") + + " ) partitioned by (date_field date) " + + s"stored 
as $format" - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - val table_create_sql = - s"create table if not exists $table_name (" + - fields - .filterNot(e => e._1.equals("date_field")) - .map(f => s"${f._1} ${f._2}") - .mkString(",") + - " ) partitioned by (date_field date) " + - s"stored as $format" - - writeAndCheckRead( - table_name, - writeIntoNewTableWithSql(table_name, table_create_sql), - fields.keys.toSeq) - } + writeAndCheckRead( + table_name, + writeIntoNewTableWithSql(table_name, table_create_sql), + fields_.keys.toSeq) + } } } // even if disable native writer, this UT fail, spark bug??? ignore("test 1-col partitioned table, partitioned by already ordered column") { withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) { - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) val originDF = spark.createDataFrame(genTestData()) originDF.createOrReplaceTempView("origin_table") @@ -410,7 +370,7 @@ class GlutenClickHouseNativeWriteTableSuite val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + - fields + fields_ .filterNot(e => e._1.equals("date_field")) .map(f => s"${f._1} ${f._2}") .mkString(",") + @@ -420,31 +380,27 @@ class GlutenClickHouseNativeWriteTableSuite spark.sql(s"drop table IF EXISTS $table_name") spark.sql(table_create_sql) spark.sql( - s"insert overwrite $table_name select ${fields.mkString(",")}" + + s"insert overwrite $table_name select ${fields_.mkString(",")}" + s" from origin_table order by date_field") } } } test("test 2-col partitioned table") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date"), - ("byte_field", "byte") - ) - - for (format <- formats) { + val fields: ListMap[String, String] = ListMap( + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)"), + ("date_field", "date"), + ("byte_field", "byte") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -458,7 +414,6 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } @@ -506,25 +461,21 @@ class GlutenClickHouseNativeWriteTableSuite // This test case will be failed with 
incorrect result randomly, ignore first. ignore("test hive parquet/orc table, all columns being partitioned. ") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("date_field", "date"), - ("timestamp_field", "timestamp"), - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)") - ) - - for (format <- formats) { + val fields: ListMap[String, String] = ListMap( + ("date_field", "date"), + ("timestamp_field", "timestamp"), + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("byte_field", "byte"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -540,20 +491,15 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } - test(("test hive parquet/orc table with aggregated results")) { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("sum(int_field)", "bigint") - ) - - for (format <- formats) { + test("test hive parquet/orc table with aggregated results") { + val fields: ListMap[String, String] = ListMap( + ("sum(int_field)", "bigint") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -566,29 +512,12 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } test("test 1-col partitioned + 1-col bucketed table") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { + nativeWrite { + format => // spark write does not support bucketed table // https://issues.apache.org/jira/browse/SPARK-19256 val table_name = table_name_template.format(format) @@ -604,7 +533,7 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "byte_field") .saveAsTable(table_name) }, - fields.keys.toSeq + fields_.keys.toSeq ) assert( @@ -614,10 +543,8 @@ class GlutenClickHouseNativeWriteTableSuite .filter(!_.getName.equals("date_field=__HIVE_DEFAULT_PARTITION__")) .head .listFiles() - .filter(!_.isHidden) - .length == 2 + .count(!_.isHidden) == 2 ) // 2 bucket files - } } } @@ -745,8 +672,8 @@ class GlutenClickHouseNativeWriteTableSuite } test("test consecutive blocks having same partition value") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF 
EXISTS $table_name") @@ -760,15 +687,14 @@ class GlutenClickHouseNativeWriteTableSuite .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select sum(id) from " + table_name).collect().apply(0).apply(0) + val ret = spark.sql(s"select sum(id) from $table_name").collect().apply(0).apply(0) assert(ret == 449985000) - } } } test("test decimal with rand()") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") spark @@ -778,32 +704,30 @@ class GlutenClickHouseNativeWriteTableSuite .format(format) .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select max(p) from " + table_name).collect().apply(0).apply(0) - } + val ret = spark.sql(s"select max(p) from $table_name").collect().apply(0).apply(0) } } test("test partitioned by constant") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { - spark.sql(s"drop table IF EXISTS tmp_123_$format") - spark.sql( - s"create table tmp_123_$format(" + - s"x1 string, x2 bigint,x3 string, x4 bigint, x5 string )" + - s"partitioned by (day date) stored as $format") - - spark.sql( - s"insert into tmp_123_$format partition(day) " + - "select cast(id as string), id, cast(id as string), id, cast(id as string), " + - "'2023-05-09' from range(10000000)") - } + nativeWrite2 { + format => + val table_name = s"tmp_123_$format" + val create_sql = + s"""create table tmp_123_$format( + |x1 string, x2 bigint,x3 string, x4 bigint, x5 string ) + |partitioned by (day date) stored as $format""".stripMargin + val insert_sql = + s"""insert into tmp_123_$format partition(day) + |select cast(id as string), id, cast(id as string), + | id, cast(id as string), '2023-05-09' + |from range(10000000)""".stripMargin + (table_name, create_sql, insert_sql) } } test("test bucketed by constant") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -815,15 +739,13 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(10000000)(spark.table(table_name).count()) } } test("test consecutive null values being partitioned") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -835,14 +757,13 @@ class GlutenClickHouseNativeWriteTableSuite .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(30000)(spark.table(table_name).count()) } } test("test consecutive null values being bucketed") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -854,78 +775,79 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(30000)(spark.table(table_name).count()) } } test("test native 
write with empty dataset") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite2( + format => { val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, str string) stored as $format") - spark.sql( - s"insert into $table_name select id, cast(id as string) from range(10)" + - " where id > 100") + ( + table_name, + s"create table $table_name (id int, str string) stored as $format", + s"insert into $table_name select id, cast(id as string) from range(10) where id > 100" + ) + }, + (table_name, _, _) => { + assertResult(0)(spark.table(table_name).count()) } - } + ) } test("test native write with union") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, str string) stored as $format") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string) from range(10) union all " + - "select 10, '10' from range(10)") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string) from range(10) union all " + - "select 10, cast(id as string) from range(10)") - - } + withDestinationTable( + table_name, + s"create table $table_name (id int, str string) stored as $format") { + checkNativeWrite( + s"insert overwrite table $table_name " + + "select id, cast(id as string) from range(10) union all " + + "select 10, '10' from range(10)", + checkNative = true) + checkNativeWrite( + s"insert overwrite table $table_name " + + "select id, cast(id as string) from range(10) union all " + + "select 10, cast(id as string) from range(10)", + checkNative = true + ) + } } } test("test native write and non-native read consistency") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { - val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, name string, info char(4)) stored as $format") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string), concat('aaa', cast(id as string)) from range(10)") + nativeWrite2( + { + format => + val table_name = "t_" + format + ( + table_name, + s"create table $table_name (id int, name string, info char(4)) stored as $format", + s"insert overwrite table $table_name " + + "select id, cast(id as string), concat('aaa', cast(id as string)) from range(10)" + ) + }, + (table_name, _, _) => compareResultsAgainstVanillaSpark( s"select * from $table_name", compareResult = true, _ => {}) - } - } + ) } test("GLUTEN-4316: fix crash on dynamic partition inserting") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - formats.foreach( - format => { - val tbl = "t_" + format - spark.sql(s"drop table IF EXISTS $tbl") - val sql1 = - s"create table $tbl(a int, b map, c struct) " + - s"partitioned by (day string) stored as $format" - val sql2 = s"insert overwrite $tbl partition (day) " + - s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + - s"struct('1', null) as c, '2024-01-08' as day from range(10)" - spark.sql(sql1) - spark.sql(sql2) - }) + nativeWrite2 { + format => + val tbl = "t_" + format + val sql1 = + s"create table $tbl(a int, b map, c struct) " + + s"partitioned by (day string) stored as 
$format" + val sql2 = s"insert overwrite $tbl partition (day) " + + s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + + s"struct('1', null) as c, '2024-01-08' as day from range(10)" + (tbl, sql1, sql2) } } - } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala index 1d3bbec848bc..188995f11058 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTPCHSaltNullParquetSuite.scala @@ -2048,10 +2048,15 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr """ |select to_json(struct(cast(id as string), id, 1.1, 1.1f, 1.1d)) from range(3) |""".stripMargin + val sql1 = + """ + | select to_json(named_struct('name', concat('/val/', id))) from range(3) + |""".stripMargin // cast('nan' as double) output 'NaN' in Spark, 'nan' in CH // cast('inf' as double) output 'Infinity' in Spark, 'inf' in CH // ignore them temporarily runQueryAndCompare(sql)(checkGlutenOperatorMatch[ProjectExecTransformer]) + runQueryAndCompare(sql1)(checkGlutenOperatorMatch[ProjectExecTransformer]) } test("GLUTEN-3501: test json output format with struct contains null value") { @@ -2570,12 +2575,12 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table test_tbl_5096") } - test("GLUTEN-5896: Bug fix greatest diff") { + test("GLUTEN-5896: Bug fix greatest/least diff") { val tbl_create_sql = "create table test_tbl_5896(id bigint, x1 int, x2 int, x3 int) using parquet" val tbl_insert_sql = "insert into test_tbl_5896 values(1, 12, NULL, 13), (2, NULL, NULL, NULL), (3, 11, NULL, NULL), (4, 10, 9, 8)" - val select_sql = "select id, greatest(x1, x2, x3) from test_tbl_5896" + val select_sql = "select id, greatest(x1, x2, x3), least(x1, x2, x3) from test_tbl_5896" spark.sql(tbl_create_sql) spark.sql(tbl_insert_sql) compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) @@ -2638,5 +2643,55 @@ class GlutenClickHouseTPCHSaltNullParquetSuite extends GlutenClickHouseTPCHAbstr spark.sql("drop table test_tbl_5910_0") spark.sql("drop table test_tbl_5910_1") } + + test("GLUTEN-4451: Fix schema may be changed by filter") { + val create_sql = + """ + |create table if not exists test_tbl_4451( + | month_day string, + | month_dif int, + | is_month_new string, + | country string, + | os string, + | mr bigint + |) using parquet + |PARTITIONED BY ( + | day string, + | app_name string) + |""".stripMargin + val insert_sql1 = + "INSERT into test_tbl_4451 partition (day='2024-06-01', app_name='abc') " + + "values('2024-06-01', 0, '1', 'CN', 'iOS', 100)" + val insert_sql2 = + "INSERT into test_tbl_4451 partition (day='2024-06-01', app_name='abc') " + + "values('2024-06-01', 0, '1', 'CN', 'iOS', 50)" + val insert_sql3 = + "INSERT into test_tbl_4451 partition (day='2024-06-01', app_name='abc') " + + "values('2024-06-01', 1, '1', 'CN', 'iOS', 80)" + spark.sql(create_sql) + spark.sql(insert_sql1) + spark.sql(insert_sql2) + spark.sql(insert_sql3) + val select_sql = + """ + |SELECT * FROM ( + | SELECT + | month_day, + | country, + | if(os = 'ALite','Android',os) AS os, + | is_month_new, + | nvl(sum(if(month_dif = 0, mr, 0)),0) AS `month0_n`, + | nvl(sum(if(month_dif = 1, mr, 0)) / sum(if(month_dif = 0, mr, 0)),0) AS 
`month1_rate`, + | '2024-06-18' as day, + | app_name + | FROM test_tbl_4451 + | GROUP BY month_day,country,if(os = 'ALite','Android',os),is_month_new,app_name + |) tt + |WHERE month0_n > 0 AND month1_rate <= 1 AND os IN ('all','Android','iOS') + | AND app_name IS NOT NULL + |""".stripMargin + compareResultsAgainstVanillaSpark(select_sql, true, { _ => }) + spark.sql("drop table test_tbl_4451") + } } // scalastyle:on line.size.limit diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala index 09fa3ff109f2..1b3df81667a0 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala @@ -46,7 +46,7 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite .set("spark.io.compression.codec", "LZ4") .set("spark.sql.shuffle.partitions", "1") .set("spark.sql.autoBroadcastJoinThreshold", "10MB") - .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG") + // .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG") .set( "spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size", s"$parquetMaxBlockSize") diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala new file mode 100644 index 000000000000..79616d52d0bc --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.gluten + +import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.FakeRowAdaptor +import org.apache.spark.sql.util.QueryExecutionListener + +trait NativeWriteChecker extends GlutenClickHouseWholeStageTransformerSuite { + + def checkNativeWrite(sqlStr: String, checkNative: Boolean): Unit = { + var nativeUsed = false + + val queryListener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + if (!nativeUsed) { + nativeUsed = if (isSparkVersionGE("3.4")) { + false + } else { + qe.executedPlan.find(_.isInstanceOf[FakeRowAdaptor]).isDefined + } + } + } + } + + try { + spark.listenerManager.register(queryListener) + spark.sql(sqlStr) + spark.sparkContext.listenerBus.waitUntilEmpty() + assertResult(checkNative)(nativeUsed) + } finally { + spark.listenerManager.unregister(queryListener) + } + } +} diff --git a/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java b/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java index 4b609769b2ab..8bfe8bad5c01 100644 --- a/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java +++ b/backends-velox/src/main/java/org/apache/gluten/udf/UdfJniWrapper.java @@ -18,7 +18,5 @@ public class UdfJniWrapper { - public UdfJniWrapper() {} - - public native void getFunctionSignatures(); + public static native void getFunctionSignatures(); } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala index 880e1e56b852..22862156c6b2 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxIteratorApi.scala @@ -26,6 +26,7 @@ import org.apache.gluten.substrait.plan.PlanNode import org.apache.gluten.substrait.rel.{LocalFilesBuilder, LocalFilesNode, SplitInfo} import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat import org.apache.gluten.utils._ +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized._ import org.apache.spark.{SparkConf, TaskContext} @@ -36,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter} import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile} import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.{BinaryType, DateType, Decimal, DecimalType, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ExecutorManager diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala index 41b56804b50b..81f06478cbb6 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala @@ -161,10 +161,9 @@ class VeloxListenerApi extends ListenerApi { private 
def initialize(conf: SparkConf, isDriver: Boolean): Unit = { SparkDirectoryUtil.init(conf) UDFResolver.resolveUdfConf(conf, isDriver = isDriver) - val debugJni = conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_MODE, defaultValue = false) && - conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, defaultValue = false) - if (debugJni) { - JniWorkspace.enableDebug() + if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, defaultValue = false)) { + val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR) + JniWorkspace.enableDebug(debugDir) } val loader = JniWorkspace.getDefault.libLoader diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 1f868c4c2044..b48da15683e8 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -827,15 +827,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { buf.result } - /** - * Generate extended columnar post-rules. - * - * @return - */ - override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { - SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() - } - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List(ArrowConvertorRule) } @@ -861,7 +852,9 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { Sig[VeloxBloomFilterMightContain](ExpressionNames.MIGHT_CONTAIN), Sig[VeloxBloomFilterAggregate](ExpressionNames.BLOOM_FILTER_AGG), Sig[TransformKeys](TRANSFORM_KEYS), - Sig[TransformValues](TRANSFORM_VALUES) + Sig[TransformValues](TRANSFORM_VALUES), + // For test purpose. 
+ Sig[VeloxDummyExpression](VeloxDummyExpression.VELOX_DUMMY_EXPRESSION) ) } diff --git a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala index 7c3ca8fc8cde..a8e65b0539c7 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/datasource/ArrowCSVFileFormat.scala @@ -21,7 +21,8 @@ import org.apache.gluten.exception.SchemaMismatchException import org.apache.gluten.execution.RowToVeloxColumnarExec import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.arrow.pool.ArrowNativeMemoryPool -import org.apache.gluten.utils.{ArrowUtil, Iterators} +import org.apache.gluten.utils.ArrowUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ArrowWritableColumnVector import org.apache.spark.TaskContext diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala index 5c9c5889bd13..d694f15fa9bd 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/RowToVeloxColumnarExec.scala @@ -22,7 +22,8 @@ import org.apache.gluten.exception.GlutenException import org.apache.gluten.exec.Runtimes import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.{ArrowAbiUtil, Iterators} +import org.apache.gluten.utils.ArrowAbiUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized._ import org.apache.spark.broadcast.Broadcast diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala index 8c2834574204..4b4db703de7a 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxAppendBatchesExec.scala @@ -17,7 +17,8 @@ package org.apache.gluten.execution import org.apache.gluten.extension.GlutenPlan -import org.apache.gluten.utils.{Iterators, VeloxBatchAppender} +import org.apache.gluten.utils.VeloxBatchAppender +import org.apache.gluten.utils.iterator.Iterators import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala index 17d0522d0732..fe3c0b7e3938 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxBroadcastBuildSideRDD.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.execution -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.spark.{broadcast, SparkContext} import org.apache.spark.sql.execution.joins.BuildSideRelation diff --git a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala index 77bf49727283..0d6714d3af92 
100644
--- a/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/execution/VeloxColumnarToRowExec.scala
@@ -20,7 +20,7 @@ import org.apache.gluten.columnarbatch.ColumnarBatches
 import org.apache.gluten.exception.GlutenNotSupportException
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.memory.nmm.NativeMemoryManagers
-import org.apache.gluten.utils.Iterators
+import org.apache.gluten.utils.iterator.Iterators
 import org.apache.gluten.vectorized.NativeColumnarToRowJniWrapper
 
 import org.apache.spark.broadcast.Broadcast
diff --git a/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
new file mode 100644
index 000000000000..e2af66b599d3
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/gluten/expression/DummyExpression.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types.DataType
+
+abstract class DummyExpression(child: Expression) extends UnaryExpression with Serializable {
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
+    defineCodeGen(ctx, ev, c => c)
+
+  override def dataType: DataType = child.dataType
+
+  override def eval(input: InternalRow): Any = {
+    assert(input.numFields == 1, "The input row of DummyExpression should have only 1 field.")
+    accessor(input, 0)
+  }
+}
+
+// Can be used as a wrapper to force the original expression to fall back, mocking the fallback behavior
+// of a supported expression in Gluten that fails native validation.
+case class VeloxDummyExpression(child: Expression) + extends DummyExpression(child) + with Transformable { + override def getTransformer( + childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer = { + if (childrenTransformers.size != children.size) { + throw new IllegalStateException( + this.getClass.getSimpleName + + ": getTransformer called before children transformer initialized.") + } + + GenericExpressionTransformer( + VeloxDummyExpression.VELOX_DUMMY_EXPRESSION, + childrenTransformers, + this) + } + + override protected def withNewChildInternal(newChild: Expression): Expression = copy(newChild) +} + +object VeloxDummyExpression { + val VELOX_DUMMY_EXPRESSION = "velox_dummy_expression" + + private val identifier = new FunctionIdentifier(VELOX_DUMMY_EXPRESSION) + + def registerFunctions(registry: FunctionRegistry): Unit = { + registry.registerFunction( + identifier, + new ExpressionInfo(classOf[VeloxDummyExpression].getName, VELOX_DUMMY_EXPRESSION), + (e: Seq[Expression]) => VeloxDummyExpression(e.head) + ) + } + + def unregisterFunctions(registry: FunctionRegistry): Unit = { + registry.dropFunction(identifier) + } +} diff --git a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala index d5639057dac8..88280ff2edde 100644 --- a/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala +++ b/backends-velox/src/main/scala/org/apache/spark/api/python/ColumnarArrowEvalPythonExec.scala @@ -20,7 +20,8 @@ import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.exception.GlutenException import org.apache.gluten.extension.GlutenPlan import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators -import org.apache.gluten.utils.{Iterators, PullOutProjectHelper} +import org.apache.gluten.utils.PullOutProjectHelper +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ArrowWritableColumnVector import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext} diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala index 7385c53d61b3..cb65b7504bfc 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/execution/ColumnarCachedBatchSerializer.scala @@ -23,7 +23,8 @@ import org.apache.gluten.exec.Runtimes import org.apache.gluten.execution.{RowToVeloxColumnarExec, VeloxColumnarToRowExec} import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.{ArrowAbiUtil, Iterators} +import org.apache.gluten.utils.ArrowAbiUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ColumnarBatchSerializerJniWrapper import org.apache.spark.internal.Logging diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala index 089db1da1dee..b2905e157554 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala +++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/datasources/VeloxWriteQueue.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.gluten.datasource.DatasourceJniWrapper -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.ColumnarBatchInIterator import org.apache.spark.TaskContext diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index 915fc554584c..8a549c9b4ea9 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -27,7 +27,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo, Unevaluable} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.types.DataTypeUtils @@ -94,7 +94,8 @@ case class UDFExpression( dataType: DataType, nullable: Boolean, children: Seq[Expression]) - extends Transformable { + extends Unevaluable + with Transformable { override protected def withNewChildrenInternal( newChildren: IndexedSeq[Expression]): Expression = { this.copy(children = newChildren) @@ -326,7 +327,7 @@ object UDFResolver extends Logging { case None => Seq.empty case Some(_) => - new UdfJniWrapper().getFunctionSignatures() + UdfJniWrapper.getFunctionSignatures() UDFNames.map { name => diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala index 11eaa3289cab..a2baf95ecdc0 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/ScalarFunctionsValidateSuite.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.execution +import org.apache.spark.sql.execution.ProjectExec import org.apache.spark.sql.types._ import java.sql.Timestamp @@ -156,24 +157,28 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { checkLengthAndPlan(df, 1) } - test("greatest function") { - val df = runQueryAndCompare( - "SELECT greatest(l_orderkey, l_orderkey)" + - "from lineitem limit 1")(checkGlutenOperatorMatch[ProjectExecTransformer]) - } - - test("least function") { - val df = runQueryAndCompare( - "SELECT least(l_orderkey, l_orderkey)" + - "from lineitem limit 1")(checkGlutenOperatorMatch[ProjectExecTransformer]) - } - test("Test greatest function") { runQueryAndCompare( "SELECT greatest(l_orderkey, l_orderkey)" + "from lineitem limit 1") { checkGlutenOperatorMatch[ProjectExecTransformer] } + withTempPath { + path => + spark + .sql("""SELECT * + FROM VALUES (CAST(5.345 AS DECIMAL(6, 2)), CAST(5.35 AS DECIMAL(5, 4))), + (CAST(5.315 AS DECIMAL(6, 2)), CAST(5.355 AS DECIMAL(5, 4))), + (CAST(3.345 AS DECIMAL(6, 2)), CAST(4.35 AS DECIMAL(5, 4))) AS data(a, b);""") 
+ .write + .parquet(path.getCanonicalPath) + + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view") + + runQueryAndCompare("SELECT greatest(a, b) from view") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } } test("Test least function") { @@ -182,6 +187,22 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { "from lineitem limit 1") { checkGlutenOperatorMatch[ProjectExecTransformer] } + withTempPath { + path => + spark + .sql("""SELECT * + FROM VALUES (CAST(5.345 AS DECIMAL(6, 2)), CAST(5.35 AS DECIMAL(5, 4))), + (CAST(5.315 AS DECIMAL(6, 2)), CAST(5.355 AS DECIMAL(5, 4))), + (CAST(3.345 AS DECIMAL(6, 2)), CAST(4.35 AS DECIMAL(5, 4))) AS data(a, b);""") + .write + .parquet(path.getCanonicalPath) + + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view") + + runQueryAndCompare("SELECT least(a, b) from view") { + checkGlutenOperatorMatch[ProjectExecTransformer] + } + } } test("Test hash function") { @@ -1145,7 +1166,18 @@ class ScalarFunctionsValidateSuite extends FunctionsValidateTest { runQueryAndCompare( "SELECT a, window.start, window.end, count(*) as cnt FROM" + " string_timestamp GROUP by a, window(b, '5 minutes') ORDER BY a, start;") { - checkGlutenOperatorMatch[ProjectExecTransformer] + df => + val executedPlan = getExecutedPlan(df) + assert( + executedPlan.exists(plan => plan.isInstanceOf[ProjectExecTransformer]), + s"Expect ProjectExecTransformer exists " + + s"in executedPlan:\n ${executedPlan.last}" + ) + assert( + !executedPlan.exists(plan => plan.isInstanceOf[ProjectExec]), + s"Expect ProjectExec doesn't exist " + + s"in executedPlan:\n ${executedPlan.last}" + ) } } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index a892b6f313a4..9b47a519cd28 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -19,6 +19,7 @@ package org.apache.gluten.execution import org.apache.gluten.GlutenConfig import org.apache.gluten.datasource.ArrowCSVFileFormat import org.apache.gluten.execution.datasource.v2.ArrowBatchScanExec +import org.apache.gluten.expression.VeloxDummyExpression import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.spark.SparkConf @@ -45,6 +46,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla override def beforeAll(): Unit = { super.beforeAll() createTPCHNotNullTables() + VeloxDummyExpression.registerFunctions(spark.sessionState.functionRegistry) + } + + override def afterAll(): Unit = { + VeloxDummyExpression.unregisterFunctions(spark.sessionState.functionRegistry) + super.afterAll() } override protected def sparkConf: SparkConf = { @@ -66,14 +73,20 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla test("select_part_column") { val df = runQueryAndCompare("select l_shipdate, l_orderkey from lineitem limit 1") { - df => { assert(df.schema.fields.length == 2) } + df => + { + assert(df.schema.fields.length == 2) + } } checkLengthAndPlan(df, 1) } test("select_as") { val df = runQueryAndCompare("select l_shipdate as my_col from lineitem limit 1") { - df => { assert(df.schema.fieldNames(0).equals("my_col")) } + df => + { + assert(df.schema.fieldNames(0).equals("my_col")) + } } checkLengthAndPlan(df, 1) } @@ -1074,6 +1087,13 @@ class TestOperator extends VeloxWholeStageTransformerSuite with 
AdaptiveSparkPla // No ProjectExecTransformer is introduced. checkSparkOperatorChainMatch[GenerateExecTransformer, FilterExecTransformer] } + + runQueryAndCompare( + s""" + |SELECT $func(${VeloxDummyExpression.VELOX_DUMMY_EXPRESSION}(a)) from t2; + |""".stripMargin) { + checkGlutenOperatorMatch[GenerateExecTransformer] + } } } } diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version index 1e3ac8d88ea9..54d0a74c5bb4 100644 --- a/cpp-ch/clickhouse.version +++ b/cpp-ch/clickhouse.version @@ -1,3 +1,4 @@ CH_ORG=Kyligence -CH_BRANCH=rebase_ch/20240620 -CH_COMMIT=f9c3886a767 +CH_BRANCH=rebase_ch/20240621 +CH_COMMIT=c811cbb985f + diff --git a/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp new file mode 100644 index 000000000000..5eb3a0b36057 --- /dev/null +++ b/cpp-ch/local-engine/AggregateFunctions/AggregateFunctionSparkAvg.cpp @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace DB +{ +struct Settings; + +namespace ErrorCodes +{ + +} +} + +namespace local_engine +{ +using namespace DB; + + +DataTypePtr getSparkAvgReturnType(const DataTypePtr & arg_type) +{ + const UInt32 precision_value = std::min(getDecimalPrecision(*arg_type) + 4, DecimalUtils::max_precision); + const auto scale_value = std::min(getDecimalScale(*arg_type) + 4, precision_value); + return createDecimal(precision_value, scale_value); +} + +template +requires is_decimal +class AggregateFunctionSparkAvg final : public AggregateFunctionAvg +{ +public: + using Base = AggregateFunctionAvg; + + explicit AggregateFunctionSparkAvg(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + : Base(argument_types_, createResultType(argument_types_, num_scale_, round_scale_), num_scale_) + , num_scale(num_scale_) + , round_scale(round_scale_) + { + } + + DataTypePtr createResultType(const DataTypes & argument_types_, UInt32 num_scale_, UInt32 round_scale_) + { + const DataTypePtr & data_type = argument_types_[0]; + const UInt32 precision_value = std::min(getDecimalPrecision(*data_type) + 4, DecimalUtils::max_precision); + const auto scale_value = std::min(num_scale_ + 4, precision_value); + return createDecimal(precision_value, scale_value); + } + + void insertResultInto(AggregateDataPtr __restrict place, IColumn & to, Arena *) const override + { + const DataTypePtr & result_type = this->getResultType(); + auto result_scale = getDecimalScale(*result_type); + WhichDataType which(result_type); + if (which.isDecimal32()) + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + else if 
(which.isDecimal64()) + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + else if (which.isDecimal128()) + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + else + { + assert_cast &>(to).getData().push_back( + divideDecimalAndUInt(this->data(place), num_scale, result_scale, round_scale)); + } + } + + String getName() const override { return "sparkAvg"; } + +private: + Int128 NO_SANITIZE_UNDEFINED + divideDecimalAndUInt(AvgFraction, UInt64> avg, UInt32 num_scale, UInt32 result_scale, UInt32 round_scale) const + { + auto value = avg.numerator.value; + if (result_scale > num_scale) + { + auto diff = DecimalUtils::scaleMultiplier>(result_scale - num_scale); + value = value * diff; + } + else if (result_scale < num_scale) + { + auto diff = DecimalUtils::scaleMultiplier>(num_scale - result_scale); + value = value / diff; + } + + auto result = value / avg.denominator; + + if (round_scale > result_scale) + return result; + + auto round_diff = DecimalUtils::scaleMultiplier>(result_scale - round_scale); + return (result + round_diff / 2) / round_diff * round_diff; + } + +private: + UInt32 num_scale; + UInt32 round_scale; +}; + +AggregateFunctionPtr +createAggregateFunctionSparkAvg(const std::string & name, const DataTypes & argument_types, const Array & parameters, const Settings * settings) +{ + assertNoParameters(name, parameters); + assertUnary(name, argument_types); + + AggregateFunctionPtr res; + const DataTypePtr & data_type = argument_types[0]; + if (!isDecimal(data_type)) + throw Exception( + ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument for aggregate function {}", data_type->getName(), name); + + bool allowPrecisionLoss = settings->get(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS).get(); + const UInt32 p1 = DB::getDecimalPrecision(*data_type); + const UInt32 s1 = DB::getDecimalScale(*data_type); + auto [p2, s2] = GlutenDecimalUtils::LONG_DECIMAL; + auto [_, round_scale] = GlutenDecimalUtils::dividePrecisionScale(p1, s1, p2, s2, allowPrecisionLoss); + + res.reset(createWithDecimalType(*data_type, argument_types, getDecimalScale(*data_type), round_scale)); + return res; +} + +void registerAggregateFunctionSparkAvg(AggregateFunctionFactory & factory) +{ + factory.registerFunction("sparkAvg", createAggregateFunctionSparkAvg); +} + +} diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 937beae99a6b..148e78bfbc79 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -77,6 +77,7 @@ namespace ErrorCodes { extern const int BAD_ARGUMENTS; extern const int UNKNOWN_TYPE; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; } } @@ -466,17 +467,17 @@ String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline) using namespace DB; -std::map BackendInitializerUtil::getBackendConfMap(std::string * plan) +std::map BackendInitializerUtil::getBackendConfMap(const std::string & plan) { std::map ch_backend_conf; - if (plan == nullptr) + if (plan.empty()) return ch_backend_conf; /// Parse backend configs from plan extensions do { auto plan_ptr = std::make_unique(); - auto success = plan_ptr->ParseFromString(*plan); + auto success = plan_ptr->ParseFromString(plan); if (!success) break; @@ -623,7 +624,9 @@ void BackendInitializerUtil::initSettings(std::map & b { /// Initialize default setting. 
settings.set("date_time_input_format", "best_effort"); - settings.set("mergetree.merge_after_insert", true); + settings.set(MERGETREE_MERGE_AFTER_INSERT, true); + settings.set(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, false); + settings.set(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS, true); for (const auto & [key, value] : backend_conf_map) { @@ -663,8 +666,12 @@ void BackendInitializerUtil::initSettings(std::map & b settings.set("session_timezone", time_zone_val); LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", "session_timezone", time_zone_val); } + else if (key == DECIMAL_OPERATIONS_ALLOW_PREC_LOSS) + { + settings.set(key, toField(key, value)); + LOG_DEBUG(&Poco::Logger::get("CHUtil"), "Set settings key:{} value:{}", key, value); + } } - /// Finally apply some fixed kvs to settings. settings.set("join_use_nulls", true); settings.set("input_format_orc_allow_missing_columns", true); @@ -686,6 +693,7 @@ void BackendInitializerUtil::initSettings(std::map & b settings.set("output_format_json_quote_64bit_integers", false); settings.set("output_format_json_quote_denormals", true); settings.set("output_format_json_skip_null_value_in_named_tuples", true); + settings.set("output_format_json_escape_forward_slashes", false); settings.set("function_json_value_return_type_allow_complex", true); settings.set("function_json_value_return_type_allow_nullable", true); settings.set("precise_float_parsing", true); @@ -750,7 +758,7 @@ void BackendInitializerUtil::initContexts(DB::Context::ConfigurationPtr config) size_t index_uncompressed_cache_size = config->getUInt64("index_uncompressed_cache_size", DEFAULT_INDEX_UNCOMPRESSED_CACHE_MAX_SIZE); double index_uncompressed_cache_size_ratio = config->getDouble("index_uncompressed_cache_size_ratio", DEFAULT_INDEX_UNCOMPRESSED_CACHE_SIZE_RATIO); global_context->setIndexUncompressedCache(index_uncompressed_cache_policy, index_uncompressed_cache_size, index_uncompressed_cache_size_ratio); - + String index_mark_cache_policy = config->getString("index_mark_cache_policy", DEFAULT_INDEX_MARK_CACHE_POLICY); size_t index_mark_cache_size = config->getUInt64("index_mark_cache_size", DEFAULT_INDEX_MARK_CACHE_MAX_SIZE); double index_mark_cache_size_ratio = config->getDouble("index_mark_cache_size_ratio", DEFAULT_INDEX_MARK_CACHE_SIZE_RATIO); @@ -786,6 +794,7 @@ void BackendInitializerUtil::updateNewSettings(const DB::ContextMutablePtr & con extern void registerAggregateFunctionCombinatorPartialMerge(AggregateFunctionCombinatorFactory &); extern void registerAggregateFunctionsBloomFilter(AggregateFunctionFactory &); +extern void registerAggregateFunctionSparkAvg(AggregateFunctionFactory &); extern void registerFunctions(FunctionFactory &); void registerAllFunctions() @@ -795,7 +804,7 @@ void registerAllFunctions() DB::registerAggregateFunctions(); auto & agg_factory = AggregateFunctionFactory::instance(); registerAggregateFunctionsBloomFilter(agg_factory); - + registerAggregateFunctionSparkAvg(agg_factory); { /// register aggregate function combinators from local_engine auto & factory = AggregateFunctionCombinatorFactory::instance(); @@ -840,14 +849,8 @@ void BackendInitializerUtil::initCompiledExpressionCache(DB::Context::Configurat #endif } -void BackendInitializerUtil::init_json(std::string * plan_json) -{ - auto plan_ptr = std::make_unique(); - google::protobuf::util::JsonStringToMessage(plan_json->c_str(), plan_ptr.get()); - return init(new String(plan_ptr->SerializeAsString())); -} -void BackendInitializerUtil::init(std::string * plan) +void 
BackendInitializerUtil::init(const std::string & plan)
 {
     std::map backend_conf_map = getBackendConfMap(plan);
     DB::Context::ConfigurationPtr config = initConfig(backend_conf_map);
@@ -905,7 +908,7 @@ void BackendInitializerUtil::init(std::string * plan)
     });
 }
 
-void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, std::string * plan)
+void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, const std::string & plan)
 {
     std::map backend_conf_map = getBackendConfMap(plan);
 
@@ -919,7 +922,10 @@ void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context,
 
 void BackendFinalizerUtil::finalizeGlobally()
 {
-    // Make sure client caches release before ClientCacheRegistry
+    /// Make sure that all active LocalExecutors stop before the Spark executor shuts down, otherwise crashes may happen.
+    LocalExecutor::cancelAll();
+
+    /// Make sure client caches release before ClientCacheRegistry
     ReadBufferBuilderFactory::instance().clean();
     StorageMergeTreeFactory::clear();
     auto & global_context = SerializedPlanParser::global_context;
diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h
index 50de9461f4de..0321d410a7d5 100644
--- a/cpp-ch/local-engine/Common/CHUtil.h
+++ b/cpp-ch/local-engine/Common/CHUtil.h
@@ -35,7 +35,12 @@ class QueryPlan;
 
 namespace local_engine
 {
-static const std::unordered_set BOOL_VALUE_SETTINGS{"mergetree.merge_after_insert"};
+static const String MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE = "mergetree.insert_without_local_storage";
+static const String MERGETREE_MERGE_AFTER_INSERT = "mergetree.merge_after_insert";
+static const std::string DECIMAL_OPERATIONS_ALLOW_PREC_LOSS = "spark.sql.decimalOperations.allowPrecisionLoss";
+
+static const std::unordered_set BOOL_VALUE_SETTINGS{
+    MERGETREE_MERGE_AFTER_INSERT, MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE, DECIMAL_OPERATIONS_ALLOW_PREC_LOSS};
 static const std::unordered_set LONG_VALUE_SETTINGS{
     "optimize.maxfilesize", "optimize.minFileSize", "mergetree.max_num_part_per_merge_task"};
 
@@ -135,9 +140,8 @@ class BackendInitializerUtil
     /// Initialize two kinds of resources
     /// 1. global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime
     /// 2.
session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver - static void init(std::string * plan); - static void init_json(std::string * plan_json); - static void updateConfig(const DB::ContextMutablePtr &, std::string *); + static void init(const std::string & plan); + static void updateConfig(const DB::ContextMutablePtr &, const std::string &); // use excel text parser @@ -194,7 +198,7 @@ class BackendInitializerUtil static void updateNewSettings(const DB::ContextMutablePtr &, const DB::Settings &); - static std::map getBackendConfMap(std::string * plan); + static std::map getBackendConfMap(const std::string & plan); inline static std::once_flag init_flag; inline static Poco::Logger * logger; @@ -281,10 +285,7 @@ class ConcurrentDeque return deq.empty(); } - std::deque unsafeGet() - { - return deq; - } + std::deque unsafeGet() { return deq; } private: std::deque deq; diff --git a/cpp-ch/local-engine/Common/GlutenDecimalUtils.h b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h new file mode 100644 index 000000000000..32af66ec590e --- /dev/null +++ b/cpp-ch/local-engine/Common/GlutenDecimalUtils.h @@ -0,0 +1,108 @@ +/* +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + + +namespace local_engine +{ + +class GlutenDecimalUtils +{ +public: + static constexpr size_t MAX_PRECISION = 38; + static constexpr size_t MAX_SCALE = 38; + static constexpr auto system_Default = std::tuple(MAX_PRECISION, 18); + static constexpr auto user_Default = std::tuple(10, 0); + static constexpr size_t MINIMUM_ADJUSTED_SCALE = 6; + + // The decimal types compatible with other numeric types + static constexpr auto BOOLEAN_DECIMAL = std::tuple(1, 0); + static constexpr auto BYTE_DECIMAL = std::tuple(3, 0); + static constexpr auto SHORT_DECIMAL = std::tuple(5, 0); + static constexpr auto INT_DECIMAL = std::tuple(10, 0); + static constexpr auto LONG_DECIMAL = std::tuple(20, 0); + static constexpr auto FLOAT_DECIMAL = std::tuple(14, 7); + static constexpr auto DOUBLE_DECIMAL = std::tuple(30, 15); + static constexpr auto BIGINT_DECIMAL = std::tuple(MAX_PRECISION, 0); + + static std::tuple adjustPrecisionScale(size_t precision, size_t scale) + { + if (precision <= MAX_PRECISION) + { + // Adjustment only needed when we exceed max precision + return std::tuple(precision, scale); + } + else if (scale < 0) + { + // Decimal can have negative scale (SPARK-24468). In this case, we cannot allow a precision + // loss since we would cause a loss of digits in the integer part. + // In this case, we are likely to meet an overflow. + return std::tuple(GlutenDecimalUtils::MAX_PRECISION, scale); + } + else + { + // Precision/scale exceed maximum precision. Result must be adjusted to MAX_PRECISION. 
+ auto intDigits = precision - scale; + // If original scale is less than MINIMUM_ADJUSTED_SCALE, use original scale value; otherwise + // preserve at least MINIMUM_ADJUSTED_SCALE fractional digits + auto minScaleValue = std::min(scale, GlutenDecimalUtils::MINIMUM_ADJUSTED_SCALE); + // The resulting scale is the maximum between what is available without causing a loss of + // digits for the integer part of the decimal and the minimum guaranteed scale, which is + // computed above + auto adjustedScale = std::max(GlutenDecimalUtils::MAX_PRECISION - intDigits, minScaleValue); + + return std::tuple(GlutenDecimalUtils::MAX_PRECISION, adjustedScale); + } + } + + static std::tuple dividePrecisionScale(size_t p1, size_t s1, size_t p2, size_t s2, bool allowPrecisionLoss) + { + if (allowPrecisionLoss) + { + // Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1) + // Scale: max(6, s1 + p2 + 1) + const size_t intDig = p1 - s1 + s2; + const size_t scale = std::max(MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1); + const size_t precision = intDig + scale; + return adjustPrecisionScale(precision, scale); + } + else + { + auto intDig = std::min(MAX_SCALE, p1 - s1 + s2); + auto decDig = std::min(MAX_SCALE, std::max(static_cast(6), s1 + p2 + 1)); + auto diff = (intDig + decDig) - MAX_SCALE; + if (diff > 0) + { + decDig -= diff / 2 + 1; + intDig = MAX_SCALE - decDig; + } + return std::tuple(intDig + decDig, decDig); + } + } + + static std::tuple widerDecimalType(const size_t p1, const size_t s1, const size_t p2, const size_t s2) + { + // max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2) + auto scale = std::max(s1, s2); + auto range = std::max(p1 - s1, p2 - s2); + return std::tuple(range + scale, scale); + } + +}; + +} diff --git a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp index 07a7aa6bd006..f207ad232b4f 100644 --- a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp +++ b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.cpp @@ -52,7 +52,15 @@ void GlutenDiskHDFS::createDirectories(const String & path) void GlutenDiskHDFS::removeDirectory(const String & path) { DiskObjectStorage::removeDirectory(path); - hdfsDelete(hdfs_object_storage->getHDFSFS(), path.c_str(), 1); + String abs_path = "/" + path; + hdfsDelete(hdfs_object_storage->getHDFSFS(), abs_path.c_str(), 1); +} + +void GlutenDiskHDFS::removeRecursive(const String & path) +{ + DiskObjectStorage::removeRecursive(path); + String abs_path = "/" + path; + hdfsDelete(hdfs_object_storage->getHDFSFS(), abs_path.c_str(), 1); } DiskObjectStoragePtr GlutenDiskHDFS::createDiskObjectStorage() diff --git a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h index 222b9f8928a3..97a99f1deaba 100644 --- a/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h +++ b/cpp-ch/local-engine/Disks/ObjectStorages/GlutenDiskHDFS.h @@ -57,6 +57,8 @@ class GlutenDiskHDFS : public DB::DiskObjectStorage void removeDirectory(const String & path) override; + void removeRecursive(const String & path) override; + DB::DiskObjectStoragePtr createDiskObjectStorage() override; std::unique_ptr writeFile(const String& path, size_t buf_size, DB::WriteMode mode, diff --git a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h new file mode 100644 index 000000000000..6930c1d75b79 --- /dev/null +++ b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h @@ -0,0 +1,77 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} +namespace local_engine +{ +template +class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric +{ +public: + bool useDefaultImplementationForNulls() const override { return false; } + virtual String getName() const = 0; + +private: + DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override + { + if (types.empty()) + throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", getName()); + return makeNullable(getLeastSupertype(types)); + } + + DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + { + size_t num_arguments = arguments.size(); + DB::Columns converted_columns(num_arguments); + for (size_t arg = 0; arg < num_arguments; ++arg) + converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); + auto result_column = result_type->createColumn(); + result_column->reserve(input_rows_count); + for (size_t row_num = 0; row_num < input_rows_count; ++row_num) + { + size_t best_arg = 0; + for (size_t arg = 1; arg < num_arguments; ++arg) + { + if constexpr (kind == DB::LeastGreatest::Greatest) + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); + if (cmp_result > 0) + best_arg = arg; + } + else + { + auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1); + if (cmp_result < 0) + best_arg = arg; + } + } + result_column->insertFrom(*converted_columns[best_arg], row_num); + } + return result_column; + } +}; + +} diff --git a/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp new file mode 100644 index 000000000000..d39bca5ea104 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkArrayFlatten.cpp @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include + + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int ILLEGAL_TYPE_OF_ARGUMENT; + extern const int ILLEGAL_COLUMN; +} + +/// arrayFlatten([[1, 2, 3], [4, 5]]) = [1, 2, 3, 4, 5] - flatten array. +class SparkArrayFlatten : public IFunction +{ +public: + static constexpr auto name = "sparkArrayFlatten"; + + static FunctionPtr create(ContextPtr) { return std::make_shared(); } + + size_t getNumberOfArguments() const override { return 1; } + bool useDefaultImplementationForConstants() const override { return true; } + bool isSuitableForShortCircuitArgumentsExecution(const DataTypesWithConstInfo & /*arguments*/) const override { return true; } + + DataTypePtr getReturnTypeImpl(const DataTypes & arguments) const override + { + if (!isArray(arguments[0])) + throw Exception(ErrorCodes::ILLEGAL_TYPE_OF_ARGUMENT, "Illegal type {} of argument of function {}, expected Array", + arguments[0]->getName(), getName()); + + DataTypePtr nested_type = arguments[0]; + nested_type = checkAndGetDataType(removeNullable(nested_type).get())->getNestedType(); + return nested_type; + } + + ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override + { + /** We create an array column with array elements as the most deep elements of nested arrays, + * and construct offsets by selecting elements of most deep offsets by values of ancestor offsets. + * +Example 1: + +Source column: Array(Array(UInt8)): +Row 1: [[1, 2, 3], [4, 5]], Row 2: [[6], [7, 8]] +data: [1, 2, 3], [4, 5], [6], [7, 8] +offsets: 2, 4 +data.data: 1 2 3 4 5 6 7 8 +data.offsets: 3 5 6 8 + +Result column: Array(UInt8): +Row 1: [1, 2, 3, 4, 5], Row 2: [6, 7, 8] +data: 1 2 3 4 5 6 7 8 +offsets: 5 8 + +Result offsets are selected from the most deep (data.offsets) by previous deep (offsets) (and values are decremented by one): +3 5 6 8 + ^ ^ + +Example 2: + +Source column: Array(Array(Array(UInt8))): +Row 1: [[], [[1], [], [2, 3]]], Row 2: [[[4]]] + +most deep data: 1 2 3 4 + +offsets1: 2 3 +offsets2: 0 3 4 +- ^ ^ - select by prev offsets +offsets3: 1 1 3 4 +- ^ ^ - select by prev offsets + +result offsets: 3, 4 +result: Row 1: [1, 2, 3], Row2: [4] + */ + + const ColumnArray * src_col = checkAndGetColumn(arguments[0].column.get()); + + if (!src_col) + throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} in argument of function 'arrayFlatten'", + arguments[0].column->getName()); + + const IColumn::Offsets & src_offsets = src_col->getOffsets(); + + ColumnArray::ColumnOffsets::MutablePtr result_offsets_column; + const IColumn::Offsets * prev_offsets = &src_offsets; + const IColumn * prev_data = &src_col->getData(); + bool nullable = prev_data->isNullable(); + // when array has null element, return null + if (nullable) + { + const ColumnNullable * nullable_column = checkAndGetColumn(prev_data); + prev_data = nullable_column->getNestedColumnPtr().get(); + for (size_t i = 0; i < nullable_column->size(); i++) + { + if (nullable_column->isNullAt(i)) + { + auto res= nullable_column->cloneEmpty(); + res->insertManyDefaults(input_rows_count); + return res; + } + } + } + if (isNothing(prev_data->getDataType())) + return prev_data->cloneResized(input_rows_count); + // only flatten one dimension + if (const ColumnArray * next_col = checkAndGetColumn(prev_data)) + { + result_offsets_column = 
ColumnArray::ColumnOffsets::create(input_rows_count); + + IColumn::Offsets & result_offsets = result_offsets_column->getData(); + + const IColumn::Offsets * next_offsets = &next_col->getOffsets(); + + for (size_t i = 0; i < input_rows_count; ++i) + result_offsets[i] = (*next_offsets)[(*prev_offsets)[i] - 1]; /// -1 array subscript is Ok, see PaddedPODArray + prev_data = &next_col->getData(); + } + + auto res = ColumnArray::create( + prev_data->getPtr(), + result_offsets_column ? std::move(result_offsets_column) : src_col->getOffsetsPtr()); + if (nullable) + return makeNullable(res); + return res; + } + +private: + String getName() const override + { + return name; + } +}; + +REGISTER_FUNCTION(SparkArrayFlatten) +{ + factory.registerFunction(); +} + +} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp index 9577d65ec5f7..920fe1b9c9cc 100644 --- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp +++ b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp @@ -14,58 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} -} +#include namespace local_engine { -class SparkFunctionGreatest : public DB::FunctionLeastGreatestGeneric +class SparkFunctionGreatest : public FunctionGreatestestLeast { public: static constexpr auto name = "sparkGreatest"; static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } SparkFunctionGreatest() = default; ~SparkFunctionGreatest() override = default; - bool useDefaultImplementationForNulls() const override { return false; } - -private: - DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override - { - if (types.empty()) - throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", name); - return makeNullable(getLeastSupertype(types)); - } - - DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override + String getName() const override { - size_t num_arguments = arguments.size(); - DB::Columns converted_columns(num_arguments); - for (size_t arg = 0; arg < num_arguments; ++arg) - converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); - auto result_column = result_type->createColumn(); - result_column->reserve(input_rows_count); - for (size_t row_num = 0; row_num < input_rows_count; ++row_num) - { - size_t best_arg = 0; - for (size_t arg = 1; arg < num_arguments; ++arg) - { - auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); - if (cmp_result > 0) - best_arg = arg; - } - result_column->insertFrom(*converted_columns[best_arg], row_num); - } - return result_column; - } + return name; + } }; REGISTER_FUNCTION(SparkGreatest) diff --git a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp new file mode 100644 index 000000000000..70aafdf07209 --- /dev/null +++ b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +namespace local_engine +{ +class SparkFunctionLeast : public FunctionGreatestestLeast +{ +public: + static constexpr auto name = "sparkLeast"; + static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } + SparkFunctionLeast() = default; + ~SparkFunctionLeast() override = default; + String getName() const override + { + return name; + } +}; + +REGISTER_FUNCTION(SparkLeast) +{ + factory.registerFunction(); +} +} diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp index 2b4eb824a5fd..5bb66e4b3f9d 100644 --- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp +++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp @@ -453,7 +453,7 @@ std::unique_ptr CHColumnToSparkRow::convertCHColumnToSparkRow(cons if (!block.columns()) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty columns"); std::unique_ptr spark_row_info = std::make_unique(block, masks); - spark_row_info->setBufferAddress(reinterpret_cast(alloc(spark_row_info->getTotalBytes(), 64))); + spark_row_info->setBufferAddress(static_cast(alloc(spark_row_info->getTotalBytes(), 64))); // spark_row_info->setBufferAddress(alignedAlloc(spark_row_info->getTotalBytes(), 64)); memset(spark_row_info->getBufferAddress(), 0, spark_row_info->getTotalBytes()); for (auto col_idx = 0; col_idx < spark_row_info->getNumCols(); col_idx++) diff --git a/cpp-ch/local-engine/Parser/FilterRelParser.cpp b/cpp-ch/local-engine/Parser/FilterRelParser.cpp index 4c71cc3126af..e0098f747c2a 100644 --- a/cpp-ch/local-engine/Parser/FilterRelParser.cpp +++ b/cpp-ch/local-engine/Parser/FilterRelParser.cpp @@ -59,7 +59,12 @@ DB::QueryPlanPtr FilterRelParser::parse(DB::QueryPlanPtr query_plan, const subst filter_step->setStepDescription("WHERE"); steps.emplace_back(filter_step.get()); query_plan->addStep(std::move(filter_step)); - + + // header maybe changed, need to rollback it + if (!blocksHaveEqualStructure(input_header, query_plan->getCurrentDataStream().header)) { + steps.emplace_back(getPlanParser()->addRollbackFilterHeaderStep(query_plan, input_header)); + } + // remove nullable auto * remove_null_step = getPlanParser()->addRemoveNullableStep(*query_plan, non_nullable_columns); if (remove_null_step) diff --git a/cpp-ch/local-engine/Parser/JoinRelParser.cpp b/cpp-ch/local-engine/Parser/JoinRelParser.cpp index 937e449b0825..58b156c3cf6e 100644 --- a/cpp-ch/local-engine/Parser/JoinRelParser.cpp +++ b/cpp-ch/local-engine/Parser/JoinRelParser.cpp @@ -459,7 +459,7 @@ void JoinRelParser::addConvertStep(TableJoin & table_join, DB::QueryPlan & left, rename_dag->getOutputs()[pos] = &alias; } } - rename_dag->projectInput(); + QueryPlanStepPtr project_step = std::make_unique(right.getCurrentDataStream(), rename_dag); project_step->setStepDescription("Right Table Rename"); steps.emplace_back(project_step.get()); diff --git a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp 
b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp index c36db6b7484a..b51b76b97415 100644 --- a/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp +++ b/cpp-ch/local-engine/Parser/MergeTreeRelParser.cpp @@ -211,7 +211,7 @@ PrewhereInfoPtr MergeTreeRelParser::parsePreWhereInfo(const substrait::Expressio prewhere_info->prewhere_column_name = filter_name; prewhere_info->need_filter = true; prewhere_info->remove_prewhere_column = true; - prewhere_info->prewhere_actions->projectInput(false); + for (const auto & name : input.getNames()) prewhere_info->prewhere_actions->tryRestoreColumn(name); return prewhere_info; diff --git a/cpp-ch/local-engine/Parser/ProjectRelParser.cpp b/cpp-ch/local-engine/Parser/ProjectRelParser.cpp index caf779ac13bc..eb190101f170 100644 --- a/cpp-ch/local-engine/Parser/ProjectRelParser.cpp +++ b/cpp-ch/local-engine/Parser/ProjectRelParser.cpp @@ -99,7 +99,6 @@ ProjectRelParser::SplittedActionsDAGs ProjectRelParser::splitActionsDAGInGenerat std::unordered_set first_split_nodes(array_join_node->children.begin(), array_join_node->children.end()); auto first_split_result = actions_dag->split(first_split_nodes); res.before_array_join = first_split_result.first; - res.before_array_join->projectInput(true); array_join_node = findArrayJoinNode(first_split_result.second); std::unordered_set second_split_nodes = {array_join_node}; diff --git a/cpp-ch/local-engine/Parser/RelParser.cpp b/cpp-ch/local-engine/Parser/RelParser.cpp index 7fc807827109..282339c4d641 100644 --- a/cpp-ch/local-engine/Parser/RelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParser.cpp @@ -15,12 +15,16 @@ * limitations under the License. */ #include "RelParser.h" + #include +#include + #include +#include #include -#include -#include #include +#include + namespace DB { @@ -38,7 +42,20 @@ AggregateFunctionPtr RelParser::getAggregateFunction( { auto & factory = AggregateFunctionFactory::instance(); auto action = NullsAction::EMPTY; - return factory.get(name, action, arg_types, parameters, properties); + + String function_name = name; + if (name == "avg" && isDecimal(removeNullable(arg_types[0]))) + function_name = "sparkAvg"; + else if (name == "avgPartialMerge") + { + if (auto agg_func = typeid_cast(arg_types[0].get()); + !agg_func->getArgumentsDataTypes().empty() && isDecimal(removeNullable(agg_func->getArgumentsDataTypes()[0]))) + { + function_name = "sparkAvgPartialMerge"; + } + } + + return factory.get(function_name, action, arg_types, parameters, properties); } std::optional RelParser::parseSignatureFunctionName(UInt32 function_ref) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 5f2c9cc33150..325ec32dc65f 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -87,14 +87,14 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; - extern const int NO_SUCH_DATA_PART; - extern const int UNKNOWN_FUNCTION; - extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int INVALID_JOIN_ON_EXPRESSION; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; +extern const int BAD_ARGUMENTS; +extern const int NO_SUCH_DATA_PART; +extern const int UNKNOWN_FUNCTION; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; +extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int INVALID_JOIN_ON_EXPRESSION; } } @@ -144,16 +144,13 @@ void 
SerializedPlanParser::parseExtensions( if (extension.has_extension_function()) { function_mapping.emplace( - std::to_string(extension.extension_function().function_anchor()), - extension.extension_function().name()); + std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name()); } } } std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( - const std::vector & expressions, - const Block & header, - const Block & read_schema) + const std::vector & expressions, const Block & header, const Block & read_schema) { auto actions_dag = std::make_shared(blockToNameAndTypeList(header)); NamesWithAliases required_columns; @@ -234,6 +231,7 @@ std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported projection type {}.", magic_enum::enum_name(expr.rex_type_case())); } actions_dag->project(required_columns); + actions_dag->appendInputsForUnusedColumns(header); return actions_dag; } @@ -258,8 +256,8 @@ std::string getDecimalFunction(const substrait::Type_Decimal & decimal, bool nul bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel) { - return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with( - "iterator"); + return rel.has_local_files() && rel.local_files().items().size() == 1 + && rel.local_files().items().at(0).uri_file().starts_with("iterator"); } bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel) @@ -335,6 +333,19 @@ IQueryPlanStep * SerializedPlanParser::addRemoveNullableStep(QueryPlan & plan, c return step_ptr; } +IQueryPlanStep * SerializedPlanParser::addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header) +{ + auto convert_actions_dag = ActionsDAG::makeConvertingActions( + query_plan->getCurrentDataStream().header.getColumnsWithTypeAndName(), + input_header.getColumnsWithTypeAndName(), + ActionsDAG::MatchColumnsMode::Name); + auto expression_step = std::make_unique(query_plan->getCurrentDataStream(), convert_actions_dag); + expression_step->setStepDescription("Generator for rollback filter"); + auto * step_ptr = expression_step.get(); + query_plan->addStep(std::move(expression_step)); + return step_ptr; +} + DataTypePtr wrapNullableType(substrait::Type_Nullability nullable, DataTypePtr nested_type) { return wrapNullableType(nullable == substrait::Type_Nullability_NULLABILITY_NULLABLE, nested_type); @@ -366,13 +377,13 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) return nested_type; } -QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr plan) +QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan) { - logDebugMessage(*plan, "substrait plan"); - parseExtensions(plan->extensions()); - if (plan->relations_size() == 1) + logDebugMessage(plan, "substrait plan"); + parseExtensions(plan.extensions()); + if (plan.relations_size() == 1) { - auto root_rel = plan->relations().at(0); + auto root_rel = plan.relations().at(0); if (!root_rel.has_root()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!"); @@ -573,9 +584,7 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co { if (args.size() != 2) throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Spark function extract requires two args, function:{}", - function.ShortDebugString()); + ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString()); // Get the 
first arg: field const auto & extract_field = args.at(0); @@ -655,19 +664,6 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co else ch_function_name = "reverseUTF8"; } - else if (function_name == "concat") - { - /// 1. ConcatOverloadResolver cannot build arrayConcat for Nullable(Array) type which causes failures when using functions like concat(split()). - /// So we use arrayConcat directly if the output type is array. - /// 2. CH ConcatImpl can only accept at least 2 arguments, but Spark concat can accept 1 argument, like concat('a') - /// in such case we use identity function - if (function.output_type().has_list()) - ch_function_name = "arrayConcat"; - else if (args.size() == 1) - ch_function_name = "identity"; - else - ch_function_name = "concat"; - } else ch_function_name = SCALAR_FUNCTIONS.at(function_name); @@ -691,9 +687,7 @@ void SerializedPlanParser::parseArrayJoinArguments( /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1 if (scalar_function.arguments_size() != 1) throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Argument number of arrayJoin should be 1 instead of {}", - scalar_function.arguments_size()); + ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 instead of {}", scalar_function.arguments_size()); auto function_name_copy = function_name; parseFunctionArguments(actions_dag, parsed_args, function_name_copy, scalar_function); @@ -732,11 +726,7 @@ void SerializedPlanParser::parseArrayJoinArguments( } ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( - const substrait::Expression & rel, - std::vector & result_names, - ActionsDAGPtr actions_dag, - bool keep_result, - bool position) + const substrait::Expression & rel, std::vector & result_names, ActionsDAGPtr actions_dag, bool keep_result, bool position) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -760,7 +750,8 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -852,10 +843,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -870,10 +858,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this)) { LOG_DEBUG( - &Poco::Logger::get("SerializedPlanParser"), 
- "parse function {} by function parser: {}", - func_name, - func_parser->getName()); + &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); const auto * result_node = func_parser->parse(scalar_function, actions_dag); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); @@ -942,12 +927,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared(); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } else if (startsWith(function_signature, "make_decimal:")) @@ -962,12 +945,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared(); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } @@ -985,9 +966,8 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name, CastType::accurateOrNull); } @@ -997,9 +977,8 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() ? 
local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name); } } @@ -1145,9 +1124,7 @@ void SerializedPlanParser::parseFunctionArgument( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( - ActionsDAGPtr & actions_dag, - const std::string & function_name, - const substrait::FunctionArgument & arg) + ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg) { const ActionsDAG::Node * res; if (arg.value().has_scalar_function()) @@ -1175,11 +1152,8 @@ std::pair SerializedPlanParser::convertStructFieldType(const } auto type_id = type->getTypeId(); - if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id == TypeIndex::UInt32 - || type_id == TypeIndex::UInt64) - { + if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id == TypeIndex::UInt32 || type_id == TypeIndex::UInt64) return {type, field}; - } UINT_CONVERT(type, field, Int8) UINT_CONVERT(type, field, Int16) UINT_CONVERT(type, field, Int32) @@ -1189,11 +1163,7 @@ std::pair SerializedPlanParser::convertStructFieldType(const } ActionsDAGPtr SerializedPlanParser::parseFunction( - const Block & header, - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = std::make_shared(blockToNameAndTypeList(header)); @@ -1203,11 +1173,7 @@ ActionsDAGPtr SerializedPlanParser::parseFunction( } ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression( - const Block & header, - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = std::make_shared(blockToNameAndTypeList(header)); @@ -1289,7 +1255,8 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple( = &actions_dag->addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name); auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -1514,9 +1481,7 @@ std::pair SerializedPlanParser::parseLiteral(const substrait } default: { throw Exception( - ErrorCodes::UNKNOWN_TYPE, - "Unsupported spark literal type {}", - magic_enum::enum_name(literal.literal_type_case())); + ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); } } return std::make_pair(std::move(type), std::move(field)); @@ -1718,8 +1683,7 @@ substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(con { substrait::ReadRel::ExtensionTable extension_table; google::protobuf::io::CodedInputStream coded_in( - 
reinterpret_cast(split_info.data()), - static_cast(split_info.size())); + reinterpret_cast(split_info.data()), static_cast(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = extension_table.ParseFromCodedStream(&coded_in); @@ -1733,8 +1697,7 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std:: { substrait::ReadRel::LocalFiles local_files; google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(split_info.data()), - static_cast(split_info.size())); + reinterpret_cast(split_info.data()), static_cast(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = local_files.ParseFromCodedStream(&coded_in); @@ -1744,10 +1707,44 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std:: return local_files; } +std::unique_ptr SerializedPlanParser::createExecutor(DB::QueryPlanPtr query_plan) +{ + Stopwatch stopwatch; + auto * logger = &Poco::Logger::get("SerializedPlanParser"); + const Settings & settings = context->getSettingsRef(); -QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) + QueryPriorities priorities; + auto query_status = std::make_shared( + context, + "", + context->getClientInfo(), + priorities.insert(static_cast(settings.priority)), + CurrentThread::getGroup(), + IAST::QueryKind::Select, + settings, + 0); + + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; + auto pipeline_builder = query_plan->buildQueryPipeline( + optimization_settings, + BuildQueryPipelineSettings{ + .actions_settings + = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes}, + .process_list_element = query_status}); + QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); + LOG_INFO(logger, "build pipeline {} ms", stopwatch.elapsedMicroseconds() / 1000.0); + + LOG_DEBUG( + logger, "clickhouse plan [optimization={}]:\n{}", settings.query_plan_enable_optimizations, PlanUtil::explainPlan(*query_plan)); + LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(pipeline)); + + return std::make_unique( + context, std::move(query_plan), std::move(pipeline), query_plan->getCurrentDataStream().header.cloneEmpty()); +} + +QueryPlanPtr SerializedPlanParser::parse(const std::string_view & plan) { - auto plan_ptr = std::make_unique(); + substrait::Plan s_plan; /// https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data /// Parsing may fail when the number of recursive layers is large. /// Here, set a limit large enough to avoid this problem. 
@@ -1755,11 +1752,10 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) google::protobuf::io::CodedInputStream coded_in(reinterpret_cast(plan.data()), static_cast(plan.size())); coded_in.SetRecursionLimit(100000); - auto ok = plan_ptr->ParseFromCodedStream(&coded_in); - if (!ok) + if (!s_plan.ParseFromCodedStream(&coded_in)) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); - auto res = parse(std::move(plan_ptr)); + auto res = parse(s_plan); #ifndef NDEBUG PlanUtil::checkOuputType(*res); @@ -1774,17 +1770,16 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) return res; } -QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan) +QueryPlanPtr SerializedPlanParser::parseJson(const std::string_view & json_plan) { - auto plan_ptr = std::make_unique(); - auto s = google::protobuf::util::JsonStringToMessage(absl::string_view(json_plan.c_str()), plan_ptr.get()); + substrait::Plan plan; + auto s = google::protobuf::util::JsonStringToMessage(json_plan, &plan); if (!s.ok()) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from json string failed: {}", s.ToString()); - return parse(std::move(plan_ptr)); + return parse(plan); } -SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) - : context(context_) +SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) { } @@ -1793,13 +1788,10 @@ ContextMutablePtr SerializedPlanParser::global_context = nullptr; Context::ConfigurationPtr SerializedPlanParser::config = nullptr; void SerializedPlanParser::collectJoinKeys( - const substrait::Expression & condition, - std::vector> & join_keys, - int32_t right_key_start) + const substrait::Expression & condition, std::vector> & join_keys, int32_t right_key_start) { auto condition_name = getFunctionName( - function_mapping.at(std::to_string(condition.scalar_function().function_reference())), - condition.scalar_function()); + function_mapping.at(std::to_string(condition.scalar_function().function_reference())), condition.scalar_function()); if (condition_name == "and") { collectJoinKeys(condition.scalar_function().arguments(0).value(), join_keys, right_key_start); @@ -1818,7 +1810,7 @@ void SerializedPlanParser::collectJoinKeys( } } -ActionsDAGPtr ASTParser::convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast) +ActionsDAG ASTParser::convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast) const { NamesAndTypesList aggregation_keys; ColumnNumbersList aggregation_keys_indexes_list; @@ -1827,9 +1819,9 @@ ActionsDAGPtr ASTParser::convertToActions(const NamesAndTypesList & name_and_typ ActionsMatcher::Data visitor_data( context, size_limits_for_set, - size_t(0), + static_cast(0), name_and_types, - std::make_shared(name_and_types), + ActionsDAG(name_and_types), std::make_shared(), false /* no_subqueries */, false /* no_makeset */, @@ -1849,8 +1841,8 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & auto substrait_name = function_signature.substr(0, function_signature.find(':')); auto func_parser = FunctionParserFactory::instance().tryGet(substrait_name, plan_parser); - String function_name = func_parser ? func_parser->getName() - : SerializedPlanParser::getFunctionName(function_signature, scalar_function); + String function_name + = func_parser ? 
func_parser->getName() : SerializedPlanParser::getFunctionName(function_signature, scalar_function); ASTs ast_args; parseFunctionArgumentsToAST(names, scalar_function, ast_args); @@ -1862,9 +1854,7 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & } void ASTParser::parseFunctionArgumentsToAST( - const Names & names, - const substrait::Expression_ScalarFunction & scalar_function, - ASTs & ast_args) + const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args) { const auto & args = scalar_function.arguments(); @@ -2007,12 +1997,12 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expre } } -void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAGPtr actions_dag) +void SerializedPlanParser::removeNullableForRequiredColumns( + const std::set & require_columns, const ActionsDAGPtr & actions_dag) const { for (const auto & item : require_columns) { - const auto * require_node = actions_dag->tryFindInOutputs(item); - if (require_node) + if (const auto * require_node = actions_dag->tryFindInOutputs(item)) { auto function_builder = FunctionFactory::instance().get("assumeNotNull", context); ActionsDAG::NodeRawConstPtrs args = {require_node}; @@ -2023,9 +2013,7 @@ void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & columns, - ActionsDAGPtr actions_dag, - std::map & nullable_measure_names) + const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names) { for (const auto & item : columns) { @@ -2039,6 +2027,33 @@ void SerializedPlanParser::wrapNullable( SharedContextHolder SerializedPlanParser::shared_context; +std::unordered_map LocalExecutor::executors; +std::mutex LocalExecutor::executors_mutex; + +void LocalExecutor::cancelAll() +{ + std::lock_guard lock{executors_mutex}; + + for (auto & [handle, executor] : executors) + executor->asyncCancel(); + + for (auto & [handle, executor] : executors) + executor->waitCancelFinished(); +} + +void LocalExecutor::addExecutor(LocalExecutor * executor) +{ + std::lock_guard lock{executors_mutex}; + Int64 handle = reinterpret_cast(executor); + executors.emplace(handle, executor); +} + +void LocalExecutor::removeExecutor(Int64 handle) +{ + std::lock_guard lock{executors_mutex}; + executors.erase(handle); +} + LocalExecutor::~LocalExecutor() { if (context->getConfigRef().getBool("dump_pipeline", false)) @@ -2051,86 +2066,23 @@ LocalExecutor::~LocalExecutor() } } - -void LocalExecutor::execute(QueryPlanPtr query_plan) -{ - Stopwatch stopwatch; - - const Settings & settings = context->getSettingsRef(); - current_query_plan = std::move(query_plan); - auto * logger = &Poco::Logger::get("LocalExecutor"); - - QueryPriorities priorities; - auto query_status = std::make_shared( - context, - "", - context->getClientInfo(), - priorities.insert(static_cast(settings.priority)), - CurrentThread::getGroup(), - IAST::QueryKind::Select, - settings, - 0); - - QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; - auto pipeline_builder = current_query_plan->buildQueryPipeline( - optimization_settings, - BuildQueryPipelineSettings{ - .actions_settings - = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, - .compile_expressions = CompileExpressions::yes}, - .process_list_element = query_status}); - - LOG_DEBUG(logger, "clickhouse plan after optimization:\n{}", 
PlanUtil::explainPlan(*current_query_plan)); - query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); - LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(query_pipeline)); - auto t_pipeline = stopwatch.elapsedMicroseconds(); - - executor = std::make_unique(query_pipeline); - auto t_executor = stopwatch.elapsedMicroseconds() - t_pipeline; - stopwatch.stop(); - LOG_INFO( - logger, - "build pipeline {} ms; create executor {} ms;", - t_pipeline / 1000.0, - t_executor / 1000.0); - - header = current_query_plan->getCurrentDataStream().header.cloneEmpty(); - ch_column_to_spark_row = std::make_unique(); -} - -std::unique_ptr LocalExecutor::writeBlockToSparkRow(Block & block) +std::unique_ptr LocalExecutor::writeBlockToSparkRow(const Block & block) const { return ch_column_to_spark_row->convertCHColumnToSparkRow(block); } bool LocalExecutor::hasNext() { - bool has_next; - try + size_t columns = currentBlock().columns(); + if (columns == 0 || isConsumed()) { - size_t columns = currentBlock().columns(); - if (columns == 0 || isConsumed()) - { - auto empty_block = header.cloneEmpty(); - setCurrentBlock(empty_block); - has_next = executor->pull(currentBlock()); - produce(); - } - else - { - has_next = true; - } - } - catch (Exception & e) - { - LOG_ERROR( - &Poco::Logger::get("LocalExecutor"), - "LocalExecutor run query plan failed with message: {}. Plan Explained: \n{}", - e.message(), - PlanUtil::explainPlan(*current_query_plan)); - throw; + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + bool has_next = executor->pull(currentBlock()); + produce(); + return has_next; } - return has_next; + return true; } SparkRowInfoPtr LocalExecutor::next() @@ -2169,8 +2121,35 @@ Block * LocalExecutor::nextColumnar() void LocalExecutor::cancel() { - if (executor) + asyncCancel(); + waitCancelFinished(); +} + +void LocalExecutor::asyncCancel() +{ + if (executor && !is_cancelled) + { + LOG_INFO(&Poco::Logger::get("LocalExecutor"), "Cancel LocalExecutor {}", reinterpret_cast(this)); executor->cancel(); + } +} + +void LocalExecutor::waitCancelFinished() +{ + if (executor && !is_cancelled) + { + Stopwatch watch; + Chunk chunk; + while (executor->pull(chunk)) + ; + is_cancelled = true; + + LOG_INFO( + &Poco::Logger::get("LocalExecutor"), + "Finish cancel LocalExecutor {}, takes {} ms", + reinterpret_cast(this), + watch.elapsedMilliseconds()); + } } Block & LocalExecutor::getHeader() @@ -2178,12 +2157,17 @@ Block & LocalExecutor::getHeader() return header; } -LocalExecutor::LocalExecutor(ContextPtr context_) - : context(context_) +LocalExecutor::LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan, QueryPipeline && pipeline, const Block & header_) + : query_pipeline(std::move(pipeline)) + , executor(std::make_unique(query_pipeline)) + , header(header_) + , context(context_) + , ch_column_to_spark_row(std::make_unique()) + , current_query_plan(std::move(query_plan)) { } -std::string LocalExecutor::dumpPipeline() +std::string LocalExecutor::dumpPipeline() const { const auto & processors = query_pipeline.getProcessors(); for (auto & processor : processors) @@ -2207,12 +2191,8 @@ std::string LocalExecutor::dumpPipeline() } NonNullableColumnsResolver::NonNullableColumnsResolver( - const Block & header_, - SerializedPlanParser & parser_, - const substrait::Expression & cond_rel_) - : header(header_) - , parser(parser_) - , cond_rel(cond_rel_) + const Block & header_, SerializedPlanParser & parser_, const substrait::Expression & 
cond_rel_) + : header(header_), parser(parser_), cond_rel(cond_rel_) { } @@ -2284,8 +2264,7 @@ void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression & } std::string NonNullableColumnsResolver::safeGetFunctionName( - const std::string & function_signature, - const substrait::Expression_ScalarFunction & function) + const std::string & function_signature, const substrait::Expression_ScalarFunction & function) const { try { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index ccd5c0fdc4c8..184065836e65 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -105,14 +105,14 @@ static const std::map SCALAR_FUNCTIONS {"sign", "sign"}, {"radians", "radians"}, {"greatest", "sparkGreatest"}, - {"least", "least"}, + {"least", "sparkLeast"}, {"shiftleft", "bitShiftLeft"}, {"shiftright", "bitShiftRight"}, {"check_overflow", "checkDecimalOverflowSpark"}, {"rand", "randCanonical"}, {"isnan", "isNaN"}, {"bin", "sparkBin"}, - {"rint", "sparkRint"}, + {"rint", "sparkRint"}, /// string functions {"like", "like"}, @@ -127,13 +127,11 @@ static const std::map SCALAR_FUNCTIONS {"trim", ""}, // trimLeft or trimLeftSpark, depends on argument size {"ltrim", ""}, // trimRight or trimRightSpark, depends on argument size {"rtrim", ""}, // trimBoth or trimBothSpark, depends on argument size - {"concat", ""}, /// dummy mapping {"strpos", "positionUTF8"}, {"char_length", "char_length"}, /// Notice: when input argument is binary type, corresponding ch function is length instead of char_length {"replace", "replaceAll"}, {"regexp_replace", "replaceRegexpAll"}, - // {"regexp_extract", "regexpExtract"}, {"regexp_extract_all", "regexpExtractAllSpark"}, {"chr", "char"}, {"rlike", "match"}, @@ -151,7 +149,7 @@ static const std::map SCALAR_FUNCTIONS {"initcap", "initcapUTF8"}, {"conv", "sparkConv"}, {"uuid", "generateUUIDv4"}, - {"levenshteinDistance", "editDistanceUTF8"}, + {"levenshteinDistance", "editDistanceUTF8"}, /// hash functions {"crc32", "CRC32"}, @@ -180,6 +178,7 @@ static const std::map SCALAR_FUNCTIONS {"array", "array"}, {"shuffle", "arrayShuffle"}, {"range", "range"}, /// dummy mapping + {"flatten", "sparkArrayFlatten"}, // map functions {"map", "map"}, @@ -218,6 +217,7 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c); class SerializedPlanParser; +class LocalExecutor; // Give a condition expression `cond_rel_`, found all columns with nullability that must not containt // null after this filter. 
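// Illustrative sketch, not part of this patch: with the dummy {"concat", ""} entry dropped from
// SCALAR_FUNCTIONS above, "concat" is now handled by a dedicated FunctionParser (the new
// FunctionParserConcat added later in this diff). The helper below only restates the dispatch the
// parser already uses (see parseFunctionWithDAG / ASTParser::parseToAST): try the
// FunctionParserFactory registry first, then fall back to the static Spark-name -> ClickHouse-name
// map. The helper name and its signature are hypothetical, and it is written as if it lived inside
// namespace local_engine next to the parser.
String resolveFunctionName(
    SerializedPlanParser * plan_parser, const String & function_signature, const substrait::Expression_ScalarFunction & func)
{
    // The Spark function name is the part of the substrait signature before ':'.
    auto spark_name = function_signature.substr(0, function_signature.find(':'));
    // 1. A registered FunctionParser (e.g. FunctionParserConcat) takes precedence.
    if (auto func_parser = FunctionParserFactory::instance().tryGet(spark_name, plan_parser))
        return func_parser->getName();
    // 2. Otherwise use the static mapping (getFunctionName also covers the name-depends-on-arguments cases).
    return SerializedPlanParser::getFunctionName(function_signature, func);
}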
@@ -241,7 +241,7 @@ class NonNullableColumnsResolver void visit(const substrait::Expression & expr); void visitNonNullable(const substrait::Expression & expr); - String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function); + String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function) const; }; class SerializedPlanParser @@ -257,11 +257,21 @@ class SerializedPlanParser friend class JoinRelParser; friend class MergeTreeRelParser; + std::unique_ptr createExecutor(DB::QueryPlanPtr query_plan); + + DB::QueryPlanPtr parse(const std::string_view & plan); + DB::QueryPlanPtr parse(const substrait::Plan & plan); + public: explicit SerializedPlanParser(const ContextPtr & context); - DB::QueryPlanPtr parse(const std::string & plan); - DB::QueryPlanPtr parseJson(const std::string & json_plan); - DB::QueryPlanPtr parse(std::unique_ptr plan); + + /// UT only + DB::QueryPlanPtr parseJson(const std::string_view & json_plan); + std::unique_ptr createExecutor(const substrait::Plan & plan) { return createExecutor(parse((plan))); } + /// + + template + std::unique_ptr createExecutor(const std::string_view & plan); DB::QueryPlanStepPtr parseReadRealWithLocalFile(const substrait::ReadRel & rel); DB::QueryPlanStepPtr parseReadRealWithJavaIter(const substrait::ReadRel & rel); @@ -278,7 +288,7 @@ class SerializedPlanParser materialize_inputs.emplace_back(materialize_input); } - void addSplitInfo(std::string & split_info) { split_infos.emplace_back(std::move(split_info)); } + void addSplitInfo(std::string && split_info) { split_infos.emplace_back(std::move(split_info)); } int nextSplitInfoIndex() { @@ -299,6 +309,7 @@ class SerializedPlanParser static std::string getFunctionName(const std::string & function_sig, const substrait::Expression_ScalarFunction & function); IQueryPlanStep * addRemoveNullableStep(QueryPlan & plan, const std::set & columns); + IQueryPlanStep * addRollbackFilterHeaderStep(QueryPlanPtr & query_plan, const Block & input_header); static ContextMutablePtr global_context; static Context::ConfigurationPtr config; @@ -371,7 +382,7 @@ class SerializedPlanParser const ActionsDAG::Node * toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args); // remove nullable after isNotNull - void removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAGPtr actions_dag); + void removeNullableForRequiredColumns(const std::set & require_columns, const ActionsDAGPtr & actions_dag) const; std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } static std::pair parseLiteral(const substrait::Expression_Literal & literal); void wrapNullable( @@ -393,6 +404,12 @@ class SerializedPlanParser const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field); }; +template +std::unique_ptr SerializedPlanParser::createExecutor(const std::string_view & plan) +{ + return createExecutor(JsonPlan ? 
parseJson(plan) : parse(plan)); +} + struct SparkBuffer { char * address; @@ -402,11 +419,9 @@ struct SparkBuffer class LocalExecutor : public BlockIterator { public: - LocalExecutor() = default; - explicit LocalExecutor(ContextPtr context); + LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan, QueryPipeline && pipeline, const Block & header_); ~LocalExecutor(); - void execute(QueryPlanPtr query_plan); SparkRowInfoPtr next(); Block * nextColumnar(); bool hasNext(); @@ -418,11 +433,19 @@ class LocalExecutor : public BlockIterator RelMetricPtr getMetric() const { return metric; } void setMetric(RelMetricPtr metric_) { metric = metric_; } void setExtraPlanHolder(std::vector & extra_plan_holder_) { extra_plan_holder = std::move(extra_plan_holder_); } + + static void cancelAll(); + static void addExecutor(LocalExecutor * executor); + static void removeExecutor(Int64 handle); + private: - std::unique_ptr writeBlockToSparkRow(DB::Block & block); + std::unique_ptr writeBlockToSparkRow(const DB::Block & block) const; + + void asyncCancel(); + void waitCancelFinished(); /// Dump processor runtime information to log - std::string dumpPipeline(); + std::string dumpPipeline() const; QueryPipeline query_pipeline; std::unique_ptr executor; @@ -430,10 +453,14 @@ class LocalExecutor : public BlockIterator ContextPtr context; std::unique_ptr ch_column_to_spark_row; std::unique_ptr spark_buffer; - DB::QueryPlanPtr current_query_plan; + QueryPlanPtr current_query_plan; RelMetricPtr metric; std::vector extra_plan_holder; + std::atomic is_cancelled{false}; + /// Record all active LocalExecutor in current executor to cancel them when executor receives shutdown command from driver. + static std::unordered_map executors; + static std::mutex executors_mutex; }; @@ -449,7 +476,7 @@ class ASTParser ~ASTParser() = default; ASTPtr parseToAST(const Names & names, const substrait::Expression & rel); - ActionsDAGPtr convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast); + ActionsDAG convertToActions(const NamesAndTypesList & name_and_types, const ASTPtr & ast) const; private: ContextPtr context; diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp new file mode 100644 index 000000000000..416fe7741812 --- /dev/null +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/concat.cpp @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; + extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; +} +} + +namespace local_engine +{ + +class FunctionParserConcat : public FunctionParser +{ +public: + explicit FunctionParserConcat(SerializedPlanParser * plan_parser_) : FunctionParser(plan_parser_) {} + ~FunctionParserConcat() override = default; + + static constexpr auto name = "concat"; + + String getName() const override { return name; } + + const ActionsDAG::Node * parse( + const substrait::Expression_ScalarFunction & substrait_func, + ActionsDAGPtr & actions_dag) const override + { + /* + parse concat(args) as: + 1. if output type is array, return arrayConcat(args) + 2. otherwise: + 1) if args is empty, return empty string + 2) if args have size 1, return identity(args[0]) + 3) otherwise return concat(args) + */ + auto args = parseFunctionArguments(substrait_func, "", actions_dag); + const auto & output_type = substrait_func.output_type(); + const ActionsDAG::Node * result_node = nullptr; + if (output_type.has_list()) + { + result_node = toFunctionNode(actions_dag, "arrayConcat", args); + } + else + { + if (args.empty()) + result_node = addColumnToActionsDAG(actions_dag, std::make_shared(), ""); + else if (args.size() == 1) + result_node = toFunctionNode(actions_dag, "identity", args); + else + result_node = toFunctionNode(actions_dag, "concat", args); + } + return convertNodeTypeIfNeeded(substrait_func, result_node, actions_dag); + } +}; + +static FunctionParserRegister register_concat; +} diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp index c1f2391a282c..406f2aaa23df 100644 --- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp +++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.cpp @@ -69,11 +69,23 @@ SparkMergeTreeWriter::SparkMergeTreeWriter( , bucket_dir(bucket_dir_) , thread_pool(CurrentMetrics::LocalThread, CurrentMetrics::LocalThreadActive, CurrentMetrics::LocalThreadScheduled, 1, 1, 100000) { + const DB::Settings & settings = context->getSettingsRef(); + merge_after_insert = settings.get(MERGETREE_MERGE_AFTER_INSERT).get(); + insert_without_local_storage = settings.get(MERGETREE_INSERT_WITHOUT_LOCAL_STORAGE).get(); + + Field limit_size_field; + if (settings.tryGet("optimize.minFileSize", limit_size_field)) + merge_min_size = limit_size_field.get() <= 0 ? merge_min_size : limit_size_field.get(); + + Field limit_cnt_field; + if (settings.tryGet("mergetree.max_num_part_per_merge_task", limit_cnt_field)) + merge_limit_parts = limit_cnt_field.get() <= 0 ? 
merge_limit_parts : limit_cnt_field.get(); + dest_storage = MergeTreeRelParser::parseStorage(merge_tree_table, SerializedPlanParser::global_context); + isRemoteStorage = dest_storage->getStoragePolicy()->getAnyDisk()->isRemote(); - if (dest_storage->getStoragePolicy()->getAnyDisk()->isRemote()) + if (useLocalStorage()) { - isRemoteStorage = true; temp_storage = MergeTreeRelParser::copyToDefaultPolicyStorage(merge_tree_table, SerializedPlanParser::global_context); storage = temp_storage; LOG_DEBUG( @@ -86,22 +98,14 @@ SparkMergeTreeWriter::SparkMergeTreeWriter( metadata_snapshot = storage->getInMemoryMetadataPtr(); header = metadata_snapshot->getSampleBlock(); - const DB::Settings & settings = context->getSettingsRef(); squashing = std::make_unique(header, settings.min_insert_block_size_rows, settings.min_insert_block_size_bytes); if (!partition_dir.empty()) extractPartitionValues(partition_dir, partition_values); +} - Field is_merge; - if (settings.tryGet("mergetree.merge_after_insert", is_merge)) - merge_after_insert = is_merge.get(); - - Field limit_size_field; - if (settings.tryGet("optimize.minFileSize", limit_size_field)) - merge_min_size = limit_size_field.get() <= 0 ? merge_min_size : limit_size_field.get(); - - Field limit_cnt_field; - if (settings.tryGet("mergetree.max_num_part_per_merge_task", limit_cnt_field)) - merge_limit_parts = limit_cnt_field.get() <= 0 ? merge_limit_parts : limit_cnt_field.get(); +bool SparkMergeTreeWriter::useLocalStorage() const +{ + return !insert_without_local_storage && isRemoteStorage; +} void SparkMergeTreeWriter::write(const DB::Block & block) @@ -161,7 +165,7 @@ void SparkMergeTreeWriter::manualFreeMemory(size_t before_write_memory) // It may allocate memory in the current thread and free it on a global thread. // For now, we have no way to clear that global memory through the Spark thread tracker in use. // So we manually correct the memory usage.
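// Illustrative sketch, not part of this patch: the new insert_without_local_storage switch decides
// whether parts are first staged on a local default-policy storage and later moved to the remote
// disk by commitPartToRemoteStorageIfNeeded(), or written directly to dest_storage. The helper
// below just restates useLocalStorage() with hypothetical free-standing parameter names.
bool stagesPartsLocally(bool is_remote_storage, bool insert_without_local_storage)
{
    // Local staging only applies to remote destinations, and only when it has not been disabled.
    return is_remote_storage && !insert_without_local_storage;
}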
- if (!isRemoteStorage) + if (isRemoteStorage && insert_without_local_storage) return; auto disk = storage->getStoragePolicy()->getAnyDisk(); @@ -219,7 +223,7 @@ void SparkMergeTreeWriter::saveMetadata() void SparkMergeTreeWriter::commitPartToRemoteStorageIfNeeded() { - if (!isRemoteStorage) + if (!useLocalStorage()) return; LOG_DEBUG( @@ -289,8 +293,8 @@ void SparkMergeTreeWriter::finalizeMerge() { for (const auto & disk : storage->getDisks()) { - auto full_path = storage->getFullPathOnDisk(disk); - disk->removeRecursive(full_path + "/" + tmp_part); + auto rel_path = storage->getRelativeDataPath() + "/" + tmp_part; + disk->removeRecursive(rel_path); } }); } diff --git a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h index 2b07521ede3a..13ac22394477 100644 --- a/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h +++ b/cpp-ch/local-engine/Storages/Mergetree/SparkMergeTreeWriter.h @@ -79,6 +79,7 @@ class SparkMergeTreeWriter void finalizeMerge(); bool chunkToPart(Chunk && chunk); bool blockToPart(Block & block); + bool useLocalStorage() const; CustomStorageMergeTreePtr storage = nullptr; CustomStorageMergeTreePtr dest_storage = nullptr; @@ -97,6 +98,7 @@ class SparkMergeTreeWriter std::unordered_set tmp_parts; DB::Block header; bool merge_after_insert; + bool insert_without_local_storage; FreeThreadPool thread_pool; size_t merge_min_size = 1024 * 1024 * 1024; size_t merge_limit_parts = 10; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 38f188293726..9c642d70ec27 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -36,10 +36,14 @@ #include #include #include +#include +#include #include #include #include +#include #include +#include #include #include #include @@ -51,10 +55,6 @@ #include #include #include -#include -#include -#include -#include #ifdef __cplusplus @@ -224,11 +224,9 @@ JNIEXPORT void JNI_OnUnload(JavaVM * vm, void * /*reserved*/) JNIEXPORT void Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative(JNIEnv * env, jobject, jbyteArray conf_plan) { LOCAL_ENGINE_JNI_METHOD_START - jsize plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan); jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string plan_str; - plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); - local_engine::BackendInitializerUtil::init(&plan_str); + local_engine::BackendInitializerUtil::init({reinterpret_cast(plan_buf_addr), plan_buf_size}); env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT); LOCAL_ENGINE_JNI_METHOD_END(env, ) } @@ -254,11 +252,9 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ auto query_context = local_engine::getAllocator(allocator_id)->query_context; // by task update new configs ( in case of dynamic config update ) - jsize plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan); jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string plan_str; - plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); - local_engine::BackendInitializerUtil::updateConfig(query_context, &plan_str); + local_engine::BackendInitializerUtil::updateConfig(query_context, {reinterpret_cast(plan_buf_addr), plan_buf_size}); 
local_engine::SerializedPlanParser parser(query_context); jsize iter_num = env->GetArrayLength(iter_arr); @@ -269,25 +265,22 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ parser.addInputIter(iter, materialize_input); } - for (jsize i = 0, split_info_arr_size = env->GetArrayLength(split_infos); i < split_info_arr_size; i++) { + for (jsize i = 0, split_info_arr_size = env->GetArrayLength(split_infos); i < split_info_arr_size; i++) + { jbyteArray split_info = static_cast(env->GetObjectArrayElement(split_infos, i)); - jsize split_info_size = env->GetArrayLength(split_info); + std::string::size_type split_info_size = env->GetArrayLength(split_info); jbyte * split_info_addr = env->GetByteArrayElements(split_info, nullptr); - std::string split_info_str; - split_info_str.assign(reinterpret_cast(split_info_addr), split_info_size); - parser.addSplitInfo(split_info_str); + parser.addSplitInfo(std::string{reinterpret_cast(split_info_addr), split_info_size}); } - jsize plan_size = env->GetArrayLength(plan); + std::string::size_type plan_size = env->GetArrayLength(plan); jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); - std::string plan_string; - plan_string.assign(reinterpret_cast(plan_address), plan_size); - auto query_plan = parser.parse(plan_string); - local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(query_context); + local_engine::LocalExecutor * executor + = parser.createExecutor({reinterpret_cast(plan_address), plan_size}).release(); + local_engine::LocalExecutor::addExecutor(executor); LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); executor->setMetric(parser.getMetric()); executor->setExtraPlanHolder(parser.extra_plan_holder); - executor->execute(std::move(query_plan)); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT); return reinterpret_cast(executor); @@ -315,17 +308,19 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_BatchIterator_nativeCHNext(JNI JNIEXPORT void Java_org_apache_gluten_vectorized_BatchIterator_nativeCancel(JNIEnv * env, jobject /*obj*/, jlong executor_address) { LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor::removeExecutor(executor_address); local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); executor->cancel(); - LOG_INFO(&Poco::Logger::get("jni"), "Cancel LocalExecutor {}", reinterpret_cast(executor)); + LOG_INFO(&Poco::Logger::get("jni"), "Cancel LocalExecutor {}", reinterpret_cast(executor)); LOCAL_ENGINE_JNI_METHOD_END(env, ) } JNIEXPORT void Java_org_apache_gluten_vectorized_BatchIterator_nativeClose(JNIEnv * env, jobject /*obj*/, jlong executor_address) { LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor::removeExecutor(executor_address); local_engine::LocalExecutor * executor = reinterpret_cast(executor_address); - LOG_INFO(&Poco::Logger::get("jni"), "Finalize LocalExecutor {}", reinterpret_cast(executor)); + LOG_INFO(&Poco::Logger::get("jni"), "Finalize LocalExecutor {}", reinterpret_cast(executor)); delete executor; LOCAL_ENGINE_JNI_METHOD_END(env, ) } @@ -630,8 +625,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na .max_sort_buffer_size = static_cast(max_sort_buffer_size), .spill_firstly_before_stop = static_cast(spill_firstly_before_stop), .force_external_sort = static_cast(force_external_sort), - .force_mermory_sort = static_cast(force_memory_sort) - }; + 
.force_mermory_sort = static_cast(force_memory_sort)}; auto name = jstring2string(env, short_name); local_engine::SplitterHolder * splitter; if (prefer_spill) @@ -696,8 +690,7 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_na .throw_if_memory_exceed = static_cast(throw_if_memory_exceed), .flush_block_buffer_before_evict = static_cast(flush_block_buffer_before_evict), .force_external_sort = static_cast(force_external_sort), - .force_mermory_sort = static_cast(force_memory_sort) - }; + .force_mermory_sort = static_cast(force_memory_sort)}; auto name = jstring2string(env, short_name); local_engine::SplitterHolder * splitter; splitter = new local_engine::SplitterHolder{.splitter = std::make_unique(name, options, pusher)}; @@ -768,8 +761,8 @@ JNIEXPORT void Java_org_apache_gluten_vectorized_CHShuffleSplitterJniWrapper_clo } // CHBlockConverterJniWrapper -JNIEXPORT jobject -Java_org_apache_gluten_vectorized_CHBlockConverterJniWrapper_convertColumnarToRow(JNIEnv * env, jclass, jlong block_address, jintArray masks) +JNIEXPORT jobject Java_org_apache_gluten_vectorized_CHBlockConverterJniWrapper_convertColumnarToRow( + JNIEnv * env, jclass, jlong block_address, jintArray masks) { LOCAL_ENGINE_JNI_METHOD_START local_engine::CHColumnToSparkRow converter; @@ -932,11 +925,10 @@ JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniW LOCAL_ENGINE_JNI_METHOD_START auto query_context = local_engine::getAllocator(allocator_id)->query_context; // by task update new configs ( in case of dynamic config update ) - jsize conf_plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type conf_plan_buf_size = env->GetArrayLength(conf_plan); jbyte * conf_plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string conf_plan_str; - conf_plan_str.assign(reinterpret_cast(conf_plan_buf_addr), conf_plan_buf_size); - local_engine::BackendInitializerUtil::updateConfig(query_context, &conf_plan_str); + local_engine::BackendInitializerUtil::updateConfig( + query_context, {reinterpret_cast(conf_plan_buf_addr), conf_plan_buf_size}); const auto uuid_str = jstring2string(env, uuid_); const auto task_id = jstring2string(env, task_id_); @@ -958,21 +950,18 @@ JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniW /// Parsing may fail when the number of recursive layers is large. /// Here, set a limit large enough to avoid this problem. 
/// Once this problem occurs, it is difficult to troubleshoot, because the pb of c++ will not provide any valid information - google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(plan_str.data()), static_cast(plan_str.size())); + google::protobuf::io::CodedInputStream coded_in(reinterpret_cast(plan_str.data()), static_cast(plan_str.size())); coded_in.SetRecursionLimit(100000); auto ok = plan_ptr->ParseFromCodedStream(&coded_in); if (!ok) throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); - substrait::ReadRel::ExtensionTable extension_table = - local_engine::SerializedPlanParser::parseExtensionTable(split_info_str); + substrait::ReadRel::ExtensionTable extension_table = local_engine::SerializedPlanParser::parseExtensionTable(split_info_str); auto merge_tree_table = local_engine::MergeTreeRelParser::parseMergeTreeTable(extension_table); auto uuid = uuid_str + "_" + task_id; - auto * writer = new local_engine::SparkMergeTreeWriter( - merge_tree_table, query_context, uuid, partition_dir, bucket_dir); + auto * writer = new local_engine::SparkMergeTreeWriter(merge_tree_table, query_context, uuid, partition_dir, bucket_dir); env->ReleaseByteArrayElements(plan_, plan_buf_addr, JNI_ABORT); env->ReleaseByteArrayElements(split_info_, split_info_addr, JNI_ABORT); @@ -1044,8 +1033,8 @@ JNIEXPORT void Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWr LOCAL_ENGINE_JNI_METHOD_END(env, ) } -JNIEXPORT void -Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_writeToMergeTree(JNIEnv * env, jobject, jlong instanceId, jlong block_address) +JNIEXPORT void Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_writeToMergeTree( + JNIEnv * env, jobject, jlong instanceId, jlong block_address) { LOCAL_ENGINE_JNI_METHOD_START auto * writer = reinterpret_cast(instanceId); @@ -1054,7 +1043,8 @@ Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_writeToMe LOCAL_ENGINE_JNI_METHOD_END(env, ) } -JNIEXPORT jstring Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_closeMergeTreeWriter(JNIEnv * env, jobject, jlong instanceId) +JNIEXPORT jstring +Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_closeMergeTreeWriter(JNIEnv * env, jobject, jlong instanceId) { LOCAL_ENGINE_JNI_METHOD_START auto * writer = reinterpret_cast(instanceId); @@ -1067,7 +1057,14 @@ JNIEXPORT jstring Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn } JNIEXPORT jstring Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniWrapper_nativeMergeMTParts( - JNIEnv * env, jobject, jbyteArray plan_, jbyteArray split_info_, jstring uuid_, jstring task_id_, jstring partition_dir_, jstring bucket_dir_) + JNIEnv * env, + jobject, + jbyteArray plan_, + jbyteArray split_info_, + jstring uuid_, + jstring task_id_, + jstring partition_dir_, + jstring bucket_dir_) { LOCAL_ENGINE_JNI_METHOD_START @@ -1095,16 +1092,14 @@ JNIEXPORT jstring Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn /// Parsing may fail when the number of recursive layers is large. /// Here, set a limit large enough to avoid this problem. 
/// Once this problem occurs, it is difficult to troubleshoot, because the pb of c++ will not provide any valid information - google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(plan_str.data()), static_cast(plan_str.size())); + google::protobuf::io::CodedInputStream coded_in(reinterpret_cast(plan_str.data()), static_cast(plan_str.size())); coded_in.SetRecursionLimit(100000); auto ok = plan_ptr->ParseFromCodedStream(&coded_in); if (!ok) throw DB::Exception(DB::ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); - substrait::ReadRel::ExtensionTable extension_table = - local_engine::SerializedPlanParser::parseExtensionTable(split_info_str); + substrait::ReadRel::ExtensionTable extension_table = local_engine::SerializedPlanParser::parseExtensionTable(split_info_str); google::protobuf::StringValue table; table.ParseFromString(extension_table.detail().value()); auto merge_tree_table = local_engine::parseMergeTreeTableString(table.value()); @@ -1114,12 +1109,12 @@ JNIEXPORT jstring Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn = local_engine::MergeTreeRelParser::copyToVirtualStorage(merge_tree_table, local_engine::SerializedPlanParser::global_context); local_engine::TempStorageFreer freer{temp_storage->getStorageID()}; // to release temp CustomStorageMergeTree with RAII - std::vector selected_parts - = local_engine::StorageMergeTreeFactory::instance().getDataPartsByNames(temp_storage->getStorageID(), "", merge_tree_table.getPartNames()); + std::vector selected_parts = local_engine::StorageMergeTreeFactory::instance().getDataPartsByNames( + temp_storage->getStorageID(), "", merge_tree_table.getPartNames()); std::unordered_map partition_values; - std::vector loaded = - local_engine::mergeParts(selected_parts, partition_values, uuid_str, temp_storage, partition_dir, bucket_dir); + std::vector loaded + = local_engine::mergeParts(selected_parts, partition_values, uuid_str, temp_storage, partition_dir, bucket_dir); std::vector res; for (auto & partPtr : loaded) @@ -1156,7 +1151,8 @@ JNIEXPORT jobject Java_org_apache_spark_sql_execution_datasources_CHDatasourceJn partition_col_indice_vec.push_back(pIndice[i]); env->ReleaseIntArrayElements(partitionColIndice, pIndice, JNI_ABORT); - local_engine::BlockStripes bs = local_engine::BlockStripeSplitter::split(*block, partition_col_indice_vec, hasBucket, reserve_partition_columns); + local_engine::BlockStripes bs + = local_engine::BlockStripeSplitter::split(*block, partition_col_indice_vec, hasBucket, reserve_partition_columns); auto * addresses = env->NewLongArray(bs.block_addresses.size()); @@ -1325,13 +1321,11 @@ Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE local_engine::SerializedPlanParser parser(context); jobject iter = env->NewGlobalRef(input); parser.addInputIter(iter, false); - jsize plan_size = env->GetArrayLength(plan); + std::string::size_type plan_size = env->GetArrayLength(plan); jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); - std::string plan_string; - plan_string.assign(reinterpret_cast(plan_address), plan_size); - auto query_plan = parser.parse(plan_string); - local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(context); - executor->execute(std::move(query_plan)); + local_engine::LocalExecutor * executor + = parser.createExecutor({reinterpret_cast(plan_address), plan_size}).release(); + local_engine::LocalExecutor::addExecutor(executor); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); 
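// Illustrative sketch, not part of this patch: the LocalExecutor lifecycle under the new
// SerializedPlanParser::createExecutor API, mirroring the JNI entry points in this file. The
// helper name runPlanOnce and its parameters are hypothetical, the needed includes are assumed to
// be those already present here, and the real JNI code keeps the raw pointer as a jlong handle
// rather than a unique_ptr.
void runPlanOnce(local_engine::SerializedPlanParser & parser, const substrait::Plan & plan)
{
    // Build the query plan, pipeline and executor in one step (replaces parse() + LocalExecutor::execute()).
    auto executor = parser.createExecutor(plan);
    // Register the executor so LocalExecutor::cancelAll() can reach it when the driver shuts this executor down.
    local_engine::LocalExecutor::addExecutor(executor.get());
    while (executor->hasNext())
    {
        local_engine::SparkRowInfoPtr row_info = executor->next();
        // ... consume row_info ...
    }
    // Unregister before destruction, as nativeCancel()/nativeClose() do.
    local_engine::LocalExecutor::removeExecutor(reinterpret_cast<Int64>(executor.get()));
}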
return reinterpret_cast(executor); LOCAL_ENGINE_JNI_METHOD_END(env, -1) @@ -1340,6 +1334,7 @@ Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE JNIEXPORT void Java_org_apache_gluten_vectorized_SimpleExpressionEval_nativeClose(JNIEnv * env, jclass, jlong instance) { LOCAL_ENGINE_JNI_METHOD_START + local_engine::LocalExecutor::removeExecutor(instance); local_engine::LocalExecutor * executor = reinterpret_cast(instance); delete executor; LOCAL_ENGINE_JNI_METHOD_END(env, ) @@ -1366,7 +1361,8 @@ JNIEXPORT jlong Java_org_apache_gluten_memory_alloc_CHNativeMemoryAllocator_getD return -1; } -JNIEXPORT jlong Java_org_apache_gluten_memory_alloc_CHNativeMemoryAllocator_createListenableAllocator(JNIEnv * env, jclass, jobject listener) +JNIEXPORT jlong +Java_org_apache_gluten_memory_alloc_CHNativeMemoryAllocator_createListenableAllocator(JNIEnv * env, jclass, jobject listener) { LOCAL_ENGINE_JNI_METHOD_START auto listener_wrapper = std::make_shared(env->NewGlobalRef(listener)); diff --git a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp index 89fa4fa961ea..208a3b518d45 100644 --- a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp +++ b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp @@ -154,14 +154,11 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); - } + + while (local_executor->hasNext()) + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); } } @@ -212,13 +209,12 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) + + while (local_executor->hasNext()) { - Block * block = local_executor.nextColumnar(); + Block * block = local_executor->nextColumnar(); delete block; } } @@ -238,15 +234,10 @@ DB::ContextMutablePtr global_context; std::ifstream t(path); std::string str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); std::cout << "the plan from: " << path << std::endl; - - auto query_plan = parser.parse(str); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(str); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - [[maybe_unused]] auto * x = local_executor.nextColumnar(); - } + while (local_executor->hasNext()) [[maybe_unused]] + auto * x = local_executor->nextColumnar(); } } @@ -282,14 +273,12 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - 
local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); - } + + while (local_executor->hasNext()) + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); } } @@ -320,16 +309,13 @@ DB::ContextMutablePtr global_context; .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; - - local_executor.execute(std::move(query_plan)); + auto local_executor = parser.createExecutor(*plan); local_engine::SparkRowToCHColumn converter; - while (local_executor.hasNext()) + while (local_executor->hasNext()) { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); state.ResumeTiming(); - auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor->getHeader()); state.PauseTiming(); } state.ResumeTiming(); @@ -368,16 +354,13 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; - - local_executor.execute(std::move(query_plan)); + auto local_executor = parser.createExecutor(*plan); local_engine::SparkRowToCHColumn converter; - while (local_executor.hasNext()) + while (local_executor->hasNext()) { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); state.ResumeTiming(); - auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor->getHeader()); state.PauseTiming(); } state.ResumeTiming(); @@ -485,12 +468,8 @@ DB::ContextMutablePtr global_context; y.reserve(cnt); for (auto _ : state) - { for (i = 0; i < cnt; i++) - { y[i] = add(x[i], i); - } - } } [[maybe_unused]] static void BM_TestSumInline(benchmark::State & state) @@ -504,12 +483,8 @@ DB::ContextMutablePtr global_context; y.reserve(cnt); for (auto _ : state) - { for (i = 0; i < cnt; i++) - { y[i] = x[i] + i; - } - } } [[maybe_unused]] static void BM_TestPlus(benchmark::State & state) @@ -545,9 +520,7 @@ DB::ContextMutablePtr global_context; block.insert(y); auto executable_function = function->prepare(arguments); for (auto _ : state) - { auto result = executable_function->execute(block.getColumnsWithTypeAndName(), type, rows, false); - } } [[maybe_unused]] static void BM_TestPlusEmbedded(benchmark::State & state) @@ -847,9 +820,7 @@ QueryPlanPtr joinPlan(QueryPlanPtr left, QueryPlanPtr right, String left_key, St ASTPtr rkey = std::make_shared(right_key); join->addOnKeys(lkey, rkey, true); for (const auto & column : join->columnsFromJoinedTable()) - { join->addJoinedColumn(column); - } auto left_keys = left->getCurrentDataStream().header.getNamesAndTypesList(); join->addJoinedColumnsAndCorrectTypes(left_keys, true); @@ -920,7 +891,8 @@ BENCHMARK(BM_ParquetRead)->Unit(benchmark::kMillisecond)->Iterations(10); int main(int argc, char ** argv) { - BackendInitializerUtil::init(nullptr); + std::string empty; + BackendInitializerUtil::init(empty); SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); }); ::benchmark::Initialize(&argc, argv); diff --git 
a/cpp-ch/local-engine/tests/gluten_test_util.cpp b/cpp-ch/local-engine/tests/gluten_test_util.cpp index 7fdd32d1661b..0448092b960d 100644 --- a/cpp-ch/local-engine/tests/gluten_test_util.cpp +++ b/cpp-ch/local-engine/tests/gluten_test_util.cpp @@ -62,14 +62,14 @@ ActionsDAGPtr parseFilter(const std::string & filter, const AnotherRowType & nam size_limits_for_set, static_cast(0), name_and_types, - std::make_shared(name_and_types), + ActionsDAG(name_and_types), prepared_sets /* prepared_sets */, false /* no_subqueries */, false /* no_makeset */, false /* only_consts */, info); ActionsVisitor(visitor_data).visit(ast_exp); - return ActionsDAG::buildFilterActionsDAG({visitor_data.getActions()->getOutputs().back()}, node_name_to_input_column); + return ActionsDAG::buildFilterActionsDAG({visitor_data.getActions().getOutputs().back()}, node_name_to_input_column); } const char * get_data_dir() diff --git a/cpp-ch/local-engine/tests/gluten_test_util.h b/cpp-ch/local-engine/tests/gluten_test_util.h index d4c16e9fbbd8..dba4496d6221 100644 --- a/cpp-ch/local-engine/tests/gluten_test_util.h +++ b/cpp-ch/local-engine/tests/gluten_test_util.h @@ -24,6 +24,7 @@ #include #include #include +#include #include using BlockRowType = DB::ColumnsWithTypeAndName; @@ -60,6 +61,23 @@ AnotherRowType readParquetSchema(const std::string & file); DB::ActionsDAGPtr parseFilter(const std::string & filter, const AnotherRowType & name_and_types); +namespace pb_util +{ +template +std::string JsonStringToBinary(const std::string_view & json) +{ + Message message; + std::string binary; + auto s = google::protobuf::util::JsonStringToMessage(json, &message); + if (!s.ok()) + { + const std::string err_msg{s.message()}; + throw std::runtime_error(err_msg); + } + message.SerializeToString(&binary); + return binary; +} +} } inline DB::DataTypePtr BIGINT() diff --git a/cpp-ch/local-engine/tests/gtest_local_engine.cpp b/cpp-ch/local-engine/tests/gtest_local_engine.cpp index 2d1807841041..962bf9def52e 100644 --- a/cpp-ch/local-engine/tests/gtest_local_engine.cpp +++ b/cpp-ch/local-engine/tests/gtest_local_engine.cpp @@ -16,9 +16,12 @@ */ #include #include +#include +#include + #include -#include #include +#include #include #include #include @@ -28,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -84,13 +86,23 @@ TEST(ReadBufferFromFile, seekBackwards) ASSERT_EQ(x, 8); } +INCBIN(resource_embedded_config_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/gtest_local_engine_config.json"); + +namespace DB +{ +void registerOutputFormatParquet(DB::FormatFactory & factory); +} + int main(int argc, char ** argv) { - auto * init = new 
String("{\"advancedExtensions\":{\"enhancement\":{\"@type\":\"type.googleapis.com/substrait.Expression\",\"literal\":{\"map\":{\"keyValues\":[{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level\"},\"value\":{\"string\":\"trace\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.endpoint\"},\"value\":{\"string\":\"localhost:9000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.velox.IOThreads\"},\"value\":{\"string\":\"0\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations\"},\"value\":{\"string\":\"false\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.worker.id\"},\"value\":{\"string\":\"1\"}},{\"key\":{\"string\":\"spark.memory.offHeap.enabled\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.iam.role.session.name\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.shuffle.codec\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.gluten.memory.offHeap.size.in.bytes\"},\"value\":{\"string\":\"10737418240\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.shuffle.codecBackend\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.sql.orc.compression.codec\"},\"value\":{\"string\":\"snappy\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.input.write.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.secret.key\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.access.key\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity\"},\"value\":{\"string\":\"INFO\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.path.style.access\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.timezone\"},\"value\":{\"string\":\"Asia/Shanghai\"}},{\"key\":{\"string\":\"spark.hadoop.input.read.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.use.instance.credentials\"},\"value\":{\"string\":\"false\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method\"},\"value\":{\"string\":\"snappy\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.iam.role\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.memory.task.offHeap.size.in.bytes\"},\"value\":{\"string\":\"10737418240\"}},{\"key\":{\"string\":\"spark.hadoop.input.connect.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.dfs.client.log.severity\"},\"value\":{\"string\":\"INFO\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver\"},\"value\":{\"string\":\"2\"}},{\"key\":{\"string\":\"sp
ark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.connection.ssl.enabled\"},\"value\":{\"string\":\"false\"}}]}}}}}"); + BackendInitializerUtil::init(test::pb_util::JsonStringToBinary( + {reinterpret_cast(gresource_embedded_config_jsonData), gresource_embedded_config_jsonSize})); + + auto & factory = FormatFactory::instance(); + DB::registerOutputFormatParquet(factory); - BackendInitializerUtil::init_json(std::move(init)); SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); }); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/gtest_parser.cpp b/cpp-ch/local-engine/tests/gtest_parser.cpp new file mode 100644 index 000000000000..485740191ea3 --- /dev/null +++ b/cpp-ch/local-engine/tests/gtest_parser.cpp @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include + + +using namespace local_engine; +using namespace DB; + +// Plan for https://github.com/ClickHouse/ClickHouse/pull/65234 +INCBIN(resource_embedded_pr_65234_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/clickhouse_pr_65234.json"); + +TEST(SerializedPlanParser, PR65234) +{ + const std::string split + = R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-clickhouse/target/scala-2.12/test-classes/tests-working-home/tpch-data/supplier/part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]})"; + SerializedPlanParser parser(SerializedPlanParser::global_context); + parser.addSplitInfo(test::pb_util::JsonStringToBinary(split)); + auto query_plan + = parser.parseJson({reinterpret_cast(gresource_embedded_pr_65234_jsonData), gresource_embedded_pr_65234_jsonSize}); +} + +#include +#include +#include +#include +#include + +Chunk testChunk() +{ + auto nameCol = STRING()->createColumn(); + nameCol->insert("one"); + nameCol->insert("two"); + nameCol->insert("three"); + + auto valueCol = UINT()->createColumn(); + valueCol->insert(1); + valueCol->insert(2); + valueCol->insert(3); + MutableColumns x; + x.push_back(std::move(nameCol)); + x.push_back(std::move(valueCol)); + return {std::move(x), 3}; +} + +TEST(LocalExecutor, StorageObjectStorageSink) +{ + /// 0. 
Create ObjectStorage for HDFS + auto settings = SerializedPlanParser::global_context->getSettingsRef(); + const std::string query + = R"(CREATE TABLE hdfs_engine_xxxx (name String, value UInt32) ENGINE=HDFS('hdfs://localhost:8020/clickhouse/test2', 'Parquet'))"; + DB::ParserCreateQuery parser; + std::string error_message; + const char * pos = query.data(); + auto ast = DB::tryParseQuery( + parser, + pos, + pos + query.size(), + error_message, + /* hilite = */ false, + "QUERY TEST", + /* allow_multi_statements = */ false, + 0, + settings.max_parser_depth, + settings.max_parser_backtracks, + true); + auto & create = ast->as(); + auto arg = create.storage->children[0]; + const auto * func = arg->as(); + EXPECT_TRUE(func && func->name == "HDFS"); + + DB::StorageHDFSConfiguration config; + StorageObjectStorage::Configuration::initialize(config, arg->children[0]->children, SerializedPlanParser::global_context, false); + + const std::shared_ptr object_storage + = std::dynamic_pointer_cast(config.createObjectStorage(SerializedPlanParser::global_context, false)); + EXPECT_TRUE(object_storage != nullptr); + + RelativePathsWithMetadata files_with_metadata; + object_storage->listObjects("/clickhouse", files_with_metadata, 0); + + /// 1. Create ObjectStorageSink + DB::StorageObjectStorageSink sink{ + object_storage, config.clone(), {}, {{STRING(), "name"}, {UINT(), "value"}}, SerializedPlanParser::global_context, ""}; + + /// 2. Create Chunk + /// 3. comsume + sink.consume(testChunk()); + sink.onFinish(); +} + +namespace DB +{ +SinkToStoragePtr createFilelinkSink( + const StorageMetadataPtr & metadata_snapshot, + const String & table_name_for_log, + const String & path, + CompressionMethod compression_method, + const std::optional & format_settings, + const String & format_name, + const ContextPtr & context, + int flags); +} + +INCBIN(resource_embedded_readcsv_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/read_student_option_schema.csv.json"); +TEST(LocalExecutor, StorageFileSink) +{ + const std::string split + = R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-velox/src/test/resources/datasource/csv/student_option_schema.csv","length":"56","text":{"fieldDelimiter":",","maxBlockSize":"8192","header":"1"},"schema":{"names":["id","name","language"],"struct":{"types":[{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}}]}},"metadataColumns":[{}]}]})"; + SerializedPlanParser parser(SerializedPlanParser::global_context); + parser.addSplitInfo(test::pb_util::JsonStringToBinary(split)); + auto local_executor = parser.createExecutor( + {reinterpret_cast(gresource_embedded_readcsv_jsonData), gresource_embedded_readcsv_jsonSize}); + + while (local_executor->hasNext()) + { + const Block & x = *local_executor->nextColumnar(); + EXPECT_EQ(4, x.rows()); + } + + StorageInMemoryMetadata metadata; + metadata.setColumns(ColumnsDescription::fromNamesAndTypes({{"name", STRING()}, {"value", UINT()}})); + StorageMetadataPtr metadata_ptr = std::make_shared(metadata); + + auto sink = createFilelinkSink( + metadata_ptr, + "test_table", + "/tmp/test_table.parquet", + CompressionMethod::None, + {}, + "Parquet", + SerializedPlanParser::global_context, + 0); + + sink->consume(testChunk()); + sink->onFinish(); +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json new file mode 100644 index 
000000000000..1c37b68b7144 --- /dev/null +++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json @@ -0,0 +1,273 @@ +{ + "extensions": [{ + "extensionFunction": { + "functionAnchor": 1, + "name": "is_not_null:str" + } + }, { + "extensionFunction": { + "functionAnchor": 2, + "name": "equal:str_str" + } + }, { + "extensionFunction": { + "functionAnchor": 3, + "name": "is_not_null:i64" + } + }, { + "extensionFunction": { + "name": "and:bool_bool" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["r_regionkey", "r_name"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }] + }, + "columnTypes": ["NORMAL_COL", "NORMAL_COL"] + }, + "filter": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + } + } + } + }] + } + } + }] + } + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree\u003d0\n" + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + 
"structField": { + } + } + } + }] + } + }, + "names": ["r_regionkey#72"], + "outputSchema": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + } + }] +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json new file mode 100644 index 000000000000..10f0ea3dfdad --- /dev/null +++ b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json @@ -0,0 +1,269 @@ +{ + "advancedExtensions": { + "enhancement": { + "@type": "type.googleapis.com/substrait.Expression", + "literal": { + "map": { + "keyValues": [ + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level" + }, + "value": { + "string": "test" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort" + }, + "value": { + "string": "5368709120" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.endpoint" + }, + "value": { + "string": "localhost:9000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.velox.IOThreads" + }, + "value": { + "string": "0" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations" + }, + "value": { + "string": "false" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.worker.id" + }, + "value": { + "string": "1" + } + }, + { + "key": { + "string": "spark.memory.offHeap.enabled" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.iam.role.session.name" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.shuffle.codec" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.gluten.memory.offHeap.size.in.bytes" + }, + "value": { + "string": "10737418240" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.shuffle.codecBackend" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.sql.orc.compression.codec" + }, + "value": { + "string": "snappy" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by" + }, + "value": { + "string": "5368709120" + } + }, + { + "key": { + "string": "spark.hadoop.input.write.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.secret.key" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.access.key" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity" + }, + "value": { + "string": "INFO" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.path.style.access" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.timezone" + }, + "value": { + "string": 
"Asia/Shanghai" + } + }, + { + "key": { + "string": "spark.hadoop.input.read.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.use.instance.credentials" + }, + "value": { + "string": "false" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method" + }, + "value": { + "string": "snappy" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.iam.role" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.memory.task.offHeap.size.in.bytes" + }, + "value": { + "string": "10737418240" + } + }, + { + "key": { + "string": "spark.hadoop.input.connect.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.dfs.client.log.severity" + }, + "value": { + "string": "INFO" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver" + }, + "value": { + "string": "2" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.connection.ssl.enabled" + }, + "value": { + "string": "false" + } + } + ] + } + } + } + } +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json new file mode 100644 index 000000000000..f9518d39014a --- /dev/null +++ b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json @@ -0,0 +1,77 @@ +{ + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "id", + "name", + "language" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "columnTypes": [ + "NORMAL_COL", + "NORMAL_COL", + "NORMAL_COL" + ] + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree=0\n" + } + } + } + }, + "names": [ + "id#20", + "name#21", + "language#22" + ], + "outputSchema": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + } + } + ] +} \ No newline at end of file diff --git a/cpp/CMake/Buildjemalloc_pic.cmake b/cpp/CMake/Buildjemalloc_pic.cmake new file mode 100644 index 000000000000..7c2316ea9540 --- /dev/null +++ b/cpp/CMake/Buildjemalloc_pic.cmake @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# Building Jemalloc +macro(build_jemalloc) + message(STATUS "Building Jemalloc from Source") + + if(DEFINED ENV{GLUTEN_JEMALLOC_URL}) + set(JEMALLOC_SOURCE_URL "$ENV{GLUTEN_JEMALLOC_URL}") + else() + set(JEMALLOC_BUILD_VERSION "5.2.1") + set(JEMALLOC_SOURCE_URL + "https://github.com/jemalloc/jemalloc/releases/download/${JEMALLOC_BUILD_VERSION}/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" + "https://github.com/ursa-labs/thirdparty/releases/download/latest/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" + ) + endif() + + set(JEMALLOC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-install") + set(JEMALLOC_LIB_DIR "${JEMALLOC_PREFIX}/lib") + set(JEMALLOC_INCLUDE_DIR "${JEMALLOC_PREFIX}/include") + set(JEMALLOC_STATIC_LIB + "${JEMALLOC_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}jemalloc_pic${CMAKE_STATIC_LIBRARY_SUFFIX}" + ) + set(JEMALLOC_INCLUDE "${JEMALLOC_PREFIX}/include") + set(JEMALLOC_CONFIGURE_ARGS + "AR=${CMAKE_AR}" + "CC=${CMAKE_C_COMPILER}" + "--prefix=${JEMALLOC_PREFIX}" + "--libdir=${JEMALLOC_LIB_DIR}" + "--with-jemalloc-prefix=je_gluten_" + "--with-private-namespace=je_gluten_private_" + "--without-export" + "--disable-shared" + "--disable-cxx" + "--disable-libdl" + # For fixing an issue when loading native lib: cannot allocate memory in + # static TLS block. + "--disable-initial-exec-tls" + "CFLAGS=-fPIC" + "CXXFLAGS=-fPIC") + set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) + ExternalProject_Add( + jemalloc_ep + URL ${JEMALLOC_SOURCE_URL} + PATCH_COMMAND touch doc/jemalloc.3 doc/jemalloc.html + CONFIGURE_COMMAND "./configure" ${JEMALLOC_CONFIGURE_ARGS} + BUILD_COMMAND ${JEMALLOC_BUILD_COMMAND} + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS "${JEMALLOC_STATIC_LIB}" + INSTALL_COMMAND make install) + + file(MAKE_DIRECTORY "${JEMALLOC_INCLUDE_DIR}") + add_library(jemalloc::libjemalloc STATIC IMPORTED) + set_target_properties( + jemalloc::libjemalloc + PROPERTIES INTERFACE_LINK_LIBRARIES Threads::Threads + IMPORTED_LOCATION "${JEMALLOC_STATIC_LIB}" + INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}") + add_dependencies(jemalloc::libjemalloc jemalloc_ep) +endmacro() diff --git a/cpp/CMake/Findjemalloc_pic.cmake b/cpp/CMake/Findjemalloc_pic.cmake index fae9f0d7ad80..ca7b7d213dfc 100644 --- a/cpp/CMake/Findjemalloc_pic.cmake +++ b/cpp/CMake/Findjemalloc_pic.cmake @@ -17,67 +17,25 @@ # Find Jemalloc macro(find_jemalloc) - # Find the existing Protobuf + # Find the existing jemalloc set(CMAKE_FIND_LIBRARY_SUFFIXES ".a") - find_package(jemalloc_pic) - if("${Jemalloc_LIBRARY}" STREQUAL "Jemalloc_LIBRARY-NOTFOUND") - message(FATAL_ERROR "Jemalloc Library Not Found") - endif() - set(PROTOC_BIN ${Jemalloc_PROTOC_EXECUTABLE}) -endmacro() - -# Building Jemalloc -macro(build_jemalloc) - message(STATUS "Building Jemalloc from Source") - - if(DEFINED ENV{GLUTEN_JEMALLOC_URL}) - set(JEMALLOC_SOURCE_URL "$ENV{GLUTEN_JEMALLOC_URL}") + # Find from vcpkg-installed lib path. 
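
The jemalloc build above is configured with `--with-jemalloc-prefix=je_gluten_` and a private namespace, so native code that talks to this allocator must go through the prefixed entry points rather than the standard jemalloc names. A minimal caller sketch, assuming the conventional `jemalloc/jemalloc.h` header produced by that build (the patch itself only calls `je_gluten_malloc_stats_print` from `VeloxMemoryManager.cc` later in this diff):

```cpp
// Sketch only, not part of the patch. With --with-jemalloc-prefix=je_gluten_,
// each public jemalloc symbol is renamed, e.g. malloc_stats_print() becomes
// je_gluten_malloc_stats_print(). ENABLE_JEMALLOC is the definition added in
// cpp/velox/CMakeLists.txt further down.
#ifdef ENABLE_JEMALLOC
#include <jemalloc/jemalloc.h> // assumed install path of the prefixed header
#endif

void printAllocatorStats()
{
#ifdef ENABLE_JEMALLOC
    // Same signature as malloc_stats_print(write_cb, cbopaque, opts); passing
    // nullptrs writes the default human-readable report to stderr.
    je_gluten_malloc_stats_print(nullptr, nullptr, nullptr);
#endif
}
```

This mirrors the stats dump the patch adds to `VeloxMemoryManager::~VeloxMemoryManager()`.
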
+ find_library( + JEMALLOC_LIBRARY + NAMES jemalloc_pic + PATHS + ${CMAKE_CURRENT_BINARY_DIR}/../../../dev/vcpkg/vcpkg_installed/x64-linux-avx/lib/ + NO_DEFAULT_PATH) + if("${JEMALLOC_LIBRARY}" STREQUAL "JEMALLOC_LIBRARY-NOTFOUND") + message(STATUS "Jemalloc Library Not Found.") + set(JEMALLOC_NOT_FOUND TRUE) else() - set(JEMALLOC_BUILD_VERSION "5.2.1") - set(JEMALLOC_SOURCE_URL - "https://github.com/jemalloc/jemalloc/releases/download/${JEMALLOC_BUILD_VERSION}/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" - "https://github.com/ursa-labs/thirdparty/releases/download/latest/jemalloc-${JEMALLOC_BUILD_VERSION}.tar.bz2" - ) + message(STATUS "Found jemalloc: ${JEMALLOC_LIBRARY}") + find_path(JEMALLOC_INCLUDE_DIR jemalloc/jemalloc.h) + add_library(jemalloc::libjemalloc STATIC IMPORTED) + set_target_properties( + jemalloc::libjemalloc + PROPERTIES INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}" + IMPORTED_LOCATION "${JEMALLOC_LIBRARY}") endif() - - set(JEMALLOC_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/jemalloc_ep-install") - set(JEMALLOC_LIB_DIR "${JEMALLOC_PREFIX}/lib") - set(JEMALLOC_INCLUDE_DIR "${JEMALLOC_PREFIX}/include") - set(JEMALLOC_STATIC_LIB - "${JEMALLOC_LIB_DIR}/${CMAKE_STATIC_LIBRARY_PREFIX}jemalloc_pic${CMAKE_STATIC_LIBRARY_SUFFIX}" - ) - set(JEMALLOC_INCLUDE "${JEMALLOC_PREFIX}/include") - set(JEMALLOC_CONFIGURE_ARGS - "AR=${CMAKE_AR}" - "CC=${CMAKE_C_COMPILER}" - "--prefix=${JEMALLOC_PREFIX}" - "--libdir=${JEMALLOC_LIB_DIR}" - "--with-jemalloc-prefix=je_gluten_" - "--with-private-namespace=je_gluten_private_" - "--without-export" - "--disable-shared" - "--disable-cxx" - "--disable-libdl" - "--disable-initial-exec-tls" - "CFLAGS=-fPIC" - "CXXFLAGS=-fPIC") - set(JEMALLOC_BUILD_COMMAND ${MAKE} ${MAKE_BUILD_ARGS}) - ExternalProject_Add( - jemalloc_ep - URL ${JEMALLOC_SOURCE_URL} - PATCH_COMMAND touch doc/jemalloc.3 doc/jemalloc.html - CONFIGURE_COMMAND "./configure" ${JEMALLOC_CONFIGURE_ARGS} - BUILD_COMMAND ${JEMALLOC_BUILD_COMMAND} - BUILD_IN_SOURCE 1 - BUILD_BYPRODUCTS "${JEMALLOC_STATIC_LIB}" - INSTALL_COMMAND make install) - - file(MAKE_DIRECTORY "${JEMALLOC_INCLUDE_DIR}") - add_library(jemalloc::libjemalloc STATIC IMPORTED) - set_target_properties( - jemalloc::libjemalloc - PROPERTIES INTERFACE_LINK_LIBRARIES Threads::Threads - IMPORTED_LOCATION "${JEMALLOC_STATIC_LIB}" - INTERFACE_INCLUDE_DIRECTORIES "${JEMALLOC_INCLUDE_DIR}") - add_dependencies(jemalloc::libjemalloc protobuf_ep) endmacro() diff --git a/cpp/core/CMakeLists.txt b/cpp/core/CMakeLists.txt index 4d7c30402985..e17d13581105 100644 --- a/cpp/core/CMakeLists.txt +++ b/cpp/core/CMakeLists.txt @@ -300,16 +300,6 @@ target_include_directories( set_target_properties(gluten PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${root_directory}/releases) -include(Findjemalloc_pic) -# Build Jemalloc -if(BUILD_JEMALLOC) - build_jemalloc(${STATIC_JEMALLOC}) - message(STATUS "Building Jemalloc: ${STATIC_JEMALLOC}") -else() # - find_jemalloc() - message(STATUS "Use existing Jemalloc libraries") -endif() - if(BUILD_TESTS) add_subdirectory(tests) endif() diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index c2d690a7e055..716a5f68a91c 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -576,6 +576,17 @@ find_package(Folly REQUIRED CONFIG) target_include_directories(velox PUBLIC ${GTEST_INCLUDE_DIRS} ${PROTOBUF_INCLUDE}) +if(BUILD_JEMALLOC) + include(Findjemalloc_pic) + find_jemalloc() + if(JEMALLOC_NOT_FOUND) + include(Buildjemalloc_pic) + build_jemalloc() + endif() + add_definitions(-DENABLE_JEMALLOC) + 
target_link_libraries(velox PUBLIC jemalloc::libjemalloc) +endif() + target_link_libraries(velox PUBLIC gluten) add_velox_dependencies() diff --git a/cpp/velox/memory/VeloxMemoryManager.cc b/cpp/velox/memory/VeloxMemoryManager.cc index 60c79ffe8725..efd165b736be 100644 --- a/cpp/velox/memory/VeloxMemoryManager.cc +++ b/cpp/velox/memory/VeloxMemoryManager.cc @@ -16,6 +16,10 @@ */ #include "VeloxMemoryManager.h" +#ifdef ENABLE_JEMALLOC +#include +#endif + #include "velox/common/memory/MallocAllocator.h" #include "velox/common/memory/MemoryPool.h" #include "velox/exec/MemoryReclaimer.h" @@ -74,7 +78,7 @@ class ListenableArbitrator : public velox::memory::MemoryArbitrator { uint64_t targetBytes, bool allowSpill, bool allowAbort) override { - velox::memory::ScopedMemoryArbitrationContext ctx(nullptr); + velox::memory::ScopedMemoryArbitrationContext ctx((const velox::memory::MemoryPool*)nullptr); facebook::velox::exec::MemoryReclaimer::Stats status; VELOX_CHECK_EQ(pools.size(), 1, "Gluten only has one root pool"); std::lock_guard l(mutex_); // FIXME: Do we have recursive locking for this mutex? @@ -326,6 +330,9 @@ VeloxMemoryManager::~VeloxMemoryManager() { usleep(waitMs * 1000); accumulatedWaitMs += waitMs; } +#ifdef ENABLE_JEMALLOC + je_gluten_malloc_stats_print(NULL, NULL, NULL); +#endif } } // namespace gluten diff --git a/cpp/velox/operators/functions/RegistrationAllFunctions.cc b/cpp/velox/operators/functions/RegistrationAllFunctions.cc index b827690d1cdf..6b6564fa4aa3 100644 --- a/cpp/velox/operators/functions/RegistrationAllFunctions.cc +++ b/cpp/velox/operators/functions/RegistrationAllFunctions.cc @@ -26,7 +26,6 @@ #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" #include "velox/functions/prestosql/window/WindowFunctionsRegistration.h" -#include "velox/functions/sparksql/Bitwise.h" #include "velox/functions/sparksql/Hash.h" #include "velox/functions/sparksql/Rand.h" #include "velox/functions/sparksql/Register.h" @@ -35,6 +34,14 @@ using namespace facebook; +namespace facebook::velox::functions { +void registerPrestoVectorFunctions() { + // Presto function. To be removed. + VELOX_REGISTER_VECTOR_FUNCTION(udf_arrays_overlap, "arrays_overlap"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform_keys, "transform_keys"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform_values, "transform_values"); +} +} // namespace facebook::velox::functions namespace gluten { namespace { void registerFunctionOverwrite() { @@ -45,9 +52,6 @@ void registerFunctionOverwrite() { velox::registerFunction({"round"}); velox::registerFunction({"round"}); velox::registerFunction({"round"}); - // TODO: the below rand function registry can be removed after presto function registry is removed. 
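
For context on the `velox::registerFunction(...)` calls in `registerFunctionOverwrite()` above (their template arguments are not visible in this hunk): Velox's simple-function API registers a templated functor under one or more names, and, as the removed comment noted, a later registration under the same name overwrites the earlier one, which is how the Spark-flavoured `round` replaces the Presto one. A rough sketch under those assumptions, using a hypothetical `PlusOneFunction` rather than any function from the patch:

```cpp
// Hypothetical illustration of Velox simple-function registration; the
// functions this patch registers (round, row constructors, ...) follow the
// same pattern but are not reproduced here.
#include "velox/functions/Macros.h"
#include "velox/functions/Registerer.h"

template <typename T>
struct PlusOneFunction {
  VELOX_DEFINE_FUNCTION_TYPES(T);

  // out = in + 1 for BIGINT inputs.
  void call(int64_t& out, const int64_t& in) {
    out = in + 1;
  }
};

void registerExampleFunction() {
  // Registering again under an already-used alias replaces the previous entry,
  // which is the overwrite behavior registerFunctionOverwrite() relies on.
  facebook::velox::registerFunction<PlusOneFunction, int64_t, int64_t>(
      {"plus_one"});
}
```
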
- velox::registerFunction>({"spark_rand"}); - velox::registerFunction>({"spark_rand"}); auto kRowConstructorWithNull = RowConstructorWithNullCallToSpecialForm::kRowConstructorWithNull; velox::exec::registerVectorFunction( @@ -67,19 +71,12 @@ void registerFunctionOverwrite() { velox::exec::registerFunctionCallToSpecialForm( kRowConstructorWithAllNull, std::make_unique(kRowConstructorWithAllNull)); - velox::functions::sparksql::registerBitwiseFunctions("spark_"); - velox::functions::registerBinaryIntegral({"check_add"}); - velox::functions::registerBinaryIntegral({"check_subtract"}); - velox::functions::registerBinaryIntegral({"check_multiply"}); - velox::functions::registerBinaryIntegral({"check_divide"}); + + velox::functions::registerPrestoVectorFunctions(); } } // namespace void registerAllFunctions() { - // The registration order matters. Spark sql functions are registered after - // presto sql functions to overwrite the registration for same named - // functions. - velox::functions::prestosql::registerAllScalarFunctions(); velox::functions::sparksql::registerFunctions(""); velox::aggregate::prestosql::registerAllAggregateFunctions( "", true /*registerCompanionFunctions*/, false /*onlyPrestoSignatures*/, true /*overwrite*/); diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 5555ecfef954..b842914ca933 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -391,23 +391,13 @@ std::unordered_map SubstraitParser::substraitVeloxFunc {"named_struct", "row_constructor"}, {"bit_or", "bitwise_or_agg"}, {"bit_and", "bitwise_and_agg"}, - {"bitwise_and", "spark_bitwise_and"}, - {"bitwise_not", "spark_bitwise_not"}, - {"bitwise_or", "spark_bitwise_or"}, - {"bitwise_xor", "spark_bitwise_xor"}, - // TODO: the below registry for rand functions can be removed - // after presto function registry is removed. 
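
The map trimmed here (`substraitVeloxFunctionMap_`) only needs to carry names that still differ between the Spark/Substrait side and the Velox registry; everything removed above (bitwise ops, `rand`, `forall`/`exists`, `arrays_zip`) is now resolved directly by the Spark function registrations. A hypothetical lookup helper, just to illustrate how such a name map is consulted with a pass-through default (the real lookup lives elsewhere in `SubstraitParser` and is not shown in this hunk):

```cpp
// Hypothetical illustration only: translate a Substrait/Spark function name to
// the Velox name, falling back to the original name when no mapping exists.
#include <string>
#include <unordered_map>

std::string toVeloxFunctionName(
    const std::unordered_map<std::string, std::string>& nameMap,
    const std::string& substraitName) {
  auto it = nameMap.find(substraitName);
  return it == nameMap.end() ? substraitName : it->second;
}

// With the entries this patch keeps:
//   toVeloxFunctionName(map, "murmur3hash") -> "hash_with_seed"
//   toVeloxFunctionName(map, "rand")        -> "rand" (no longer remapped)
```
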
- {"rand", "spark_rand"}, {"murmur3hash", "hash_with_seed"}, {"xxhash64", "xxhash64_with_seed"}, {"modulus", "remainder"}, {"date_format", "format_datetime"}, {"collect_set", "set_agg"}, - {"forall", "all_match"}, - {"exists", "any_match"}, {"negative", "unaryminus"}, - {"get_array_item", "get"}, - {"arrays_zip", "zip"}}; + {"get_array_item", "get"}}; const std::unordered_map SubstraitParser::typeMap_ = { {"bool", "BOOLEAN"}, diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 8b8a9262403c..73047b2f4907 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -26,6 +26,7 @@ #include "utils/ConfigExtractor.h" #include "config/GlutenConfig.h" +#include "operators/plannodes/RowVectorStream.h" namespace gluten { namespace { @@ -710,16 +711,23 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: namespace { void extractUnnestFieldExpr( - std::shared_ptr projNode, + std::shared_ptr child, int32_t index, std::vector& unnestFields) { - auto name = projNode->names()[index]; - auto expr = projNode->projections()[index]; - auto type = expr->type(); + if (auto projNode = std::dynamic_pointer_cast(child)) { + auto name = projNode->names()[index]; + auto expr = projNode->projections()[index]; + auto type = expr->type(); - auto unnestFieldExpr = std::make_shared(type, name); - VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only support field"); - unnestFields.emplace_back(unnestFieldExpr); + auto unnestFieldExpr = std::make_shared(type, name); + VELOX_CHECK_NOT_NULL(unnestFieldExpr, " the key in unnest Operator only support field"); + unnestFields.emplace_back(unnestFieldExpr); + } else { + auto name = child->outputType()->names()[index]; + auto field = child->outputType()->childAt(index); + auto unnestFieldExpr = std::make_shared(field, name); + unnestFields.emplace_back(unnestFieldExpr); + } } } // namespace @@ -752,10 +760,13 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "injectedProject="); if (injectedProject) { - auto projNode = std::dynamic_pointer_cast(childNode); + // Child should be either ProjectNode or ValueStreamNode in case of project fallback. 
VELOX_CHECK( - projNode != nullptr && projNode->names().size() > requiredChildOutput.size(), - "injectedProject is true, but the Project is missing or does not have the corresponding projection field") + (std::dynamic_pointer_cast(childNode) != nullptr || + std::dynamic_pointer_cast(childNode) != nullptr) && + childNode->outputType()->size() > requiredChildOutput.size(), + "injectedProject is true, but the ProjectNode or ValueStreamNode (in case of projection fallback)" + " is missing or does not have the corresponding projection field") bool isStack = generateRel.has_advanced_extension() && SubstraitParser::configSetInOptimization(generateRel.advanced_extension(), "isStack="); @@ -768,7 +779,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: // +- Project [fake_column#128, [1,2,3] AS _pre_0#129] // +- RewrittenNodeWall Scan OneRowRelation[fake_column#128] // The last projection column in GeneratorRel's child(Project) is the column we need to unnest - extractUnnestFieldExpr(projNode, projNode->projections().size() - 1, unnest); + auto index = childNode->outputType()->size() - 1; + extractUnnestFieldExpr(childNode, index, unnest); } else { // For stack function, e.g. stack(2, 1,2,3), a sample // input substrait plan is like the following: @@ -782,10 +794,10 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: auto generatorFunc = generator.scalar_function(); auto numRows = SubstraitParser::getLiteralValue(generatorFunc.arguments(0).value().literal()); auto numFields = static_cast(std::ceil((generatorFunc.arguments_size() - 1.0) / numRows)); - auto totalProjectCount = projNode->names().size(); + auto totalProjectCount = childNode->outputType()->size(); for (auto i = totalProjectCount - numFields; i < totalProjectCount; ++i) { - extractUnnestFieldExpr(projNode, i, unnest); + extractUnnestFieldExpr(childNode, i, unnest); } } } else { diff --git a/cpp/velox/symbols.map b/cpp/velox/symbols.map index ebd2b9af0096..525faf3526a1 100644 --- a/cpp/velox/symbols.map +++ b/cpp/velox/symbols.map @@ -6,6 +6,8 @@ }; Java_org_apache_gluten_*; + JNI_OnLoad; + JNI_OnUnload; local: # Hide symbols of static dependencies *; diff --git a/cpp/velox/utils/ConfigExtractor.cc b/cpp/velox/utils/ConfigExtractor.cc index a71f143225b9..816166351c0e 100644 --- a/cpp/velox/utils/ConfigExtractor.cc +++ b/cpp/velox/utils/ConfigExtractor.cc @@ -34,6 +34,13 @@ const bool kVeloxFileHandleCacheEnabledDefault = false; // Log granularity of AWS C++ SDK const std::string kVeloxAwsSdkLogLevel = "spark.gluten.velox.awsSdkLogLevel"; const std::string kVeloxAwsSdkLogLevelDefault = "FATAL"; +// Retry mode for AWS s3 +const std::string kVeloxS3RetryMode = "spark.gluten.velox.fs.s3a.retry.mode"; +const std::string kVeloxS3RetryModeDefault = "legacy"; +// Connection timeout for AWS s3 +const std::string kVeloxS3ConnectTimeout = "spark.gluten.velox.fs.s3a.connect.timeout"; +// Using default fs.s3a.connection.timeout value in hadoop +const std::string kVeloxS3ConnectTimeoutDefault = "200s"; } // namespace namespace gluten { @@ -64,6 +71,10 @@ std::shared_ptr getHiveConfig(std::shared_ptr< bool useInstanceCredentials = conf->get("spark.hadoop.fs.s3a.use.instance.credentials", false); std::string iamRole = conf->get("spark.hadoop.fs.s3a.iam.role", ""); std::string iamRoleSessionName = conf->get("spark.hadoop.fs.s3a.iam.role.session.name", ""); + std::string retryMaxAttempts = conf->get("spark.hadoop.fs.s3a.retry.limit", "20"); + std::string retryMode = 
conf->get(kVeloxS3RetryMode, kVeloxS3RetryModeDefault); + std::string maxConnections = conf->get("spark.hadoop.fs.s3a.connection.maximum", "15"); + std::string connectTimeout = conf->get(kVeloxS3ConnectTimeout, kVeloxS3ConnectTimeoutDefault); std::string awsSdkLogLevel = conf->get(kVeloxAwsSdkLogLevel, kVeloxAwsSdkLogLevelDefault); @@ -79,6 +90,14 @@ std::shared_ptr getHiveConfig(std::shared_ptr< if (envAwsEndpoint != nullptr) { awsEndpoint = std::string(envAwsEndpoint); } + const char* envRetryMaxAttempts = std::getenv("AWS_MAX_ATTEMPTS"); + if (envRetryMaxAttempts != nullptr) { + retryMaxAttempts = std::string(envRetryMaxAttempts); + } + const char* envRetryMode = std::getenv("AWS_RETRY_MODE"); + if (envRetryMode != nullptr) { + retryMode = std::string(envRetryMode); + } if (useInstanceCredentials) { hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3UseInstanceCredentials] = "true"; @@ -98,6 +117,10 @@ std::shared_ptr getHiveConfig(std::shared_ptr< hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3SSLEnabled] = sslEnabled ? "true" : "false"; hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3PathStyleAccess] = pathStyleAccess ? "true" : "false"; hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3LogLevel] = awsSdkLogLevel; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3MaxAttempts] = retryMaxAttempts; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3RetryMode] = retryMode; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3MaxConnections] = maxConnections; + hiveConfMap[facebook::velox::connector::hive::HiveConfig::kS3ConnectTimeout] = connectTimeout; #endif #ifdef ENABLE_GCS diff --git a/dev/vcpkg/CONTRIBUTING.md b/dev/vcpkg/CONTRIBUTING.md index b725f0b50fc5..719bc91db066 100644 --- a/dev/vcpkg/CONTRIBUTING.md +++ b/dev/vcpkg/CONTRIBUTING.md @@ -13,7 +13,7 @@ Please init vcpkg env first: Vcpkg already maintains a lot of libraries. You can find them by vcpkg cli. -(NOTE: Please always use cli beacause [packages on vcpkg.io](https://vcpkg.io/en/packages.html) is outdate). +(NOTE: Please always use cli because [packages on vcpkg.io](https://vcpkg.io/en/packages.html) is outdate). ``` $ ./.vcpkg/vcpkg search folly @@ -28,7 +28,7 @@ folly[zlib] Support zlib for compression folly[zstd] Support zstd for compression ``` -`[...]` means additional features. Then add depend into [vcpkg.json](./vcpkg.json). +`[...]` means additional features. Then add the dependency into [vcpkg.json](./vcpkg.json). ``` json { @@ -144,7 +144,7 @@ See [vcpkg.json reference](https://learn.microsoft.com/en-us/vcpkg/reference/vcp `portfile.cmake` is a cmake script describing how to build and install the package. 
A typical portfile has 3 stages: -**Download and perpare source**: +**Download and prepare source**: ``` cmake # Download from Github diff --git a/dev/vcpkg/ports/jemalloc/fix-configure-ac.patch b/dev/vcpkg/ports/jemalloc/fix-configure-ac.patch new file mode 100644 index 000000000000..7799dfb9e80e --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/fix-configure-ac.patch @@ -0,0 +1,13 @@ +diff --git a/configure.ac b/configure.ac +index f6d25f334..3115504e2 100644 +--- a/configure.ac ++++ b/configure.ac +@@ -1592,7 +1592,7 @@ fi + [enable_uaf_detection="0"] + ) + if test "x$enable_uaf_detection" = "x1" ; then +- AC_DEFINE([JEMALLOC_UAF_DETECTION], [ ]) ++ AC_DEFINE([JEMALLOC_UAF_DETECTION], [ ], ["enable UAF"]) + fi + AC_SUBST([enable_uaf_detection]) + diff --git a/dev/vcpkg/ports/jemalloc/portfile.cmake b/dev/vcpkg/ports/jemalloc/portfile.cmake new file mode 100644 index 000000000000..6cac12ca3b7c --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/portfile.cmake @@ -0,0 +1,79 @@ +vcpkg_from_github( + OUT_SOURCE_PATH SOURCE_PATH + REPO jemalloc/jemalloc + REF 54eaed1d8b56b1aa528be3bdd1877e59c56fa90c + SHA512 527bfbf5db9a5c2b7b04df4785b6ae9d445cff8cb17298bf3e550c88890d2bd7953642d8efaa417580610508279b527d3a3b9e227d17394fd2013c88cb7ae75a + HEAD_REF master + PATCHES + fix-configure-ac.patch + preprocessor.patch +) +if(VCPKG_TARGET_IS_WINDOWS) + set(opts "ac_cv_search_log=none required" + "--without-private-namespace" + "--with-jemalloc-prefix=je_gluten_" + "--with-private-namespace=je_gluten_private_" + "--without-export" + "--disable-shared" + "--disable-cxx" + "--disable-libdl" + # For fixing an issue when loading native lib: cannot allocate memory in static TLS block. + "--disable-initial-exec-tls" + "CFLAGS=-fPIC" + "CXXFLAGS=-fPIC") +else() + set(opts + "--with-jemalloc-prefix=je_gluten_" + "--with-private-namespace=je_gluten_private_" + "--without-export" + "--disable-shared" + "--disable-cxx" + "--disable-libdl" + # For fixing an issue when loading native lib: cannot allocate memory in static TLS block. 
+ "--disable-initial-exec-tls" + "CFLAGS=-fPIC" + "CXXFLAGS=-fPIC") +endif() + +vcpkg_configure_make( + SOURCE_PATH "${SOURCE_PATH}" + AUTOCONFIG + NO_WRAPPERS + OPTIONS ${opts} +) + +vcpkg_install_make() + +if(VCPKG_TARGET_IS_WINDOWS) + file(COPY "${SOURCE_PATH}/include/msvc_compat/strings.h" DESTINATION "${CURRENT_PACKAGES_DIR}/include/jemalloc/msvc_compat") + vcpkg_replace_string("${CURRENT_PACKAGES_DIR}/include/jemalloc/jemalloc.h" "" "\"msvc_compat/strings.h\"") + if(VCPKG_LIBRARY_LINKAGE STREQUAL "dynamic") + file(COPY "${CURRENT_BUILDTREES_DIR}/${TARGET_TRIPLET}-rel/lib/jemalloc.lib" DESTINATION "${CURRENT_PACKAGES_DIR}/lib") + file(MAKE_DIRECTORY "${CURRENT_PACKAGES_DIR}/bin") + file(RENAME "${CURRENT_PACKAGES_DIR}/lib/jemalloc.dll" "${CURRENT_PACKAGES_DIR}/bin/jemalloc.dll") + endif() + if(NOT VCPKG_BUILD_TYPE) + if(VCPKG_LIBRARY_LINKAGE STREQUAL "dynamic") + file(COPY "${CURRENT_BUILDTREES_DIR}/${TARGET_TRIPLET}-dbg/lib/jemalloc.lib" DESTINATION "${CURRENT_PACKAGES_DIR}/debug/lib") + file(MAKE_DIRECTORY "${CURRENT_PACKAGES_DIR}/debug/bin") + file(RENAME "${CURRENT_PACKAGES_DIR}/debug/lib/jemalloc.dll" "${CURRENT_PACKAGES_DIR}/debug/bin/jemalloc.dll") + endif() + endif() + if(VCPKG_LIBRARY_LINKAGE STREQUAL "static") + vcpkg_replace_string("${CURRENT_PACKAGES_DIR}/lib/pkgconfig/jemalloc.pc" "install_suffix=" "install_suffix=_s") + if(NOT VCPKG_BUILD_TYPE) + vcpkg_replace_string("${CURRENT_PACKAGES_DIR}/debug/lib/pkgconfig/jemalloc.pc" "install_suffix=" "install_suffix=_s") + endif() + endif() +endif() + +vcpkg_fixup_pkgconfig() + +vcpkg_copy_pdbs() + +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/debug/include") +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/debug/share") +file(REMOVE_RECURSE "${CURRENT_PACKAGES_DIR}/tools") + +# Handle copyright +file(INSTALL "${SOURCE_PATH}/COPYING" DESTINATION "${CURRENT_PACKAGES_DIR}/share/${PORT}" RENAME copyright) diff --git a/dev/vcpkg/ports/jemalloc/preprocessor.patch b/dev/vcpkg/ports/jemalloc/preprocessor.patch new file mode 100644 index 000000000000..6e6e2d1403fb --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/preprocessor.patch @@ -0,0 +1,12 @@ +diff --git a/configure.ac b/configure.ac +index 3115504e2..ffb504b08 100644 +--- a/configure.ac ++++ b/configure.ac +@@ -749,6 +749,7 @@ case "${host}" in + so="dll" + if test "x$je_cv_msvc" = "xyes" ; then + importlib="lib" ++ JE_APPEND_VS(CPPFLAGS, -DJEMALLOC_NO_PRIVATE_NAMESPACE) + DSO_LDFLAGS="-LD" + EXTRA_LDFLAGS="-link -DEBUG" + CTARGET='-Fo$@' diff --git a/dev/vcpkg/ports/jemalloc/vcpkg.json b/dev/vcpkg/ports/jemalloc/vcpkg.json new file mode 100644 index 000000000000..007e05b931c9 --- /dev/null +++ b/dev/vcpkg/ports/jemalloc/vcpkg.json @@ -0,0 +1,8 @@ +{ + "name": "jemalloc", + "version": "5.3.0", + "port-version": 1, + "description": "jemalloc is a general purpose malloc(3) implementation that emphasizes fragmentation avoidance and scalable concurrency support", + "homepage": "https://jemalloc.net/", + "license": "BSD-2-Clause" +} diff --git a/docs/Configuration.md b/docs/Configuration.md index 089675286f68..2c2bd4de11f2 100644 --- a/docs/Configuration.md +++ b/docs/Configuration.md @@ -89,6 +89,8 @@ The following configurations are related to Velox settings. | spark.gluten.sql.columnar.backend.velox.maxCoalescedBytes | Set the max coalesced bytes for velox file scan. | | | spark.gluten.sql.columnar.backend.velox.cachePrefetchMinPct | Set prefetch cache min pct for velox file scan. | | | spark.gluten.velox.awsSdkLogLevel | Log granularity of AWS C++ SDK in velox. 
| FATAL | +| spark.gluten.velox.fs.s3a.retry.mode | Retry mode for AWS s3 connection error, can be "legacy", "standard" and "adaptive". | legacy | +| spark.gluten.velox.fs.s3a.connect.timeout | Timeout for AWS s3 connection. | 1s | | spark.gluten.sql.columnar.backend.velox.orc.scan.enabled | Enable velox orc scan. If disabled, vanilla spark orc scan will be used. | true | | spark.gluten.sql.complexType.scan.fallback.enabled | Force fallback for complex type scan, including struct, map, array. | true | diff --git a/docs/get-started/ClickHouse.md b/docs/get-started/ClickHouse.md index 4352a99e55f9..ab24de7a4fd6 100644 --- a/docs/get-started/ClickHouse.md +++ b/docs/get-started/ClickHouse.md @@ -679,13 +679,13 @@ spark.shuffle.manager=org.apache.spark.shuffle.gluten.celeborn.CelebornShuffleMa quickly start a celeborn cluster ```shell -wget https://archive.apache.org/dist/incubator/celeborn/celeborn-0.3.0-incubating/apache-celeborn-0.3.0-incubating-bin.tgz && \ -tar -zxvf apache-celeborn-0.3.0-incubating-bin.tgz && \ -mv apache-celeborn-0.3.0-incubating-bin/conf/celeborn-defaults.conf.template apache-celeborn-0.3.0-incubating-bin/conf/celeborn-defaults.conf && \ -mv apache-celeborn-0.3.0-incubating-bin/conf/log4j2.xml.template apache-celeborn-0.3.0-incubating-bin/conf/log4j2.xml && \ +wget https://archive.apache.org/dist/celeborn/celeborn-0.3.2-incubating/apache-celeborn-0.3.2-incubating-bin.tgz && \ +tar -zxvf apache-celeborn-0.3.2-incubating-bin.tgz && \ +mv apache-celeborn-0.3.2-incubating-bin/conf/celeborn-defaults.conf.template apache-celeborn-0.3.2-incubating-bin/conf/celeborn-defaults.conf && \ +mv apache-celeborn-0.3.2-incubating-bin/conf/log4j2.xml.template apache-celeborn-0.3.2-incubating-bin/conf/log4j2.xml && \ mkdir /opt/hadoop && chmod 777 /opt/hadoop && \ -echo -e "celeborn.worker.flusher.threads 4\nceleborn.worker.storage.dirs /tmp\nceleborn.worker.monitor.disk.enabled false" > apache-celeborn-0.3.0-incubating-bin/conf/celeborn-defaults.conf && \ -bash apache-celeborn-0.3.0-incubating-bin/sbin/start-master.sh && bash apache-celeborn-0.3.0-incubating-bin/sbin/start-worker.sh +echo -e "celeborn.worker.flusher.threads 4\nceleborn.worker.storage.dirs /tmp\nceleborn.worker.monitor.disk.enabled false" > apache-celeborn-0.3.2-incubating-bin/conf/celeborn-defaults.conf && \ +bash apache-celeborn-0.3.2-incubating-bin/sbin/start-master.sh && bash apache-celeborn-0.3.2-incubating-bin/sbin/start-worker.sh ``` ### Columnar shuffle mode diff --git a/docs/get-started/build-guide.md b/docs/get-started/build-guide.md index 3db2244ba229..b2e4b9560301 100644 --- a/docs/get-started/build-guide.md +++ b/docs/get-started/build-guide.md @@ -14,7 +14,7 @@ Please set them via `--`, e.g. `--build_type=Release`. | build_tests | Build gluten cpp tests. | OFF | | build_examples | Build udf example. | OFF | | build_benchmarks | Build gluten cpp benchmarks. | OFF | -| build_jemalloc | Build with jemalloc. | ON | +| build_jemalloc | Build with jemalloc. | OFF | | build_protobuf | Build protobuf lib. | ON | | enable_qat | Enable QAT for shuffle data de/compression. | OFF | | enable_iaa | Enable IAA for shuffle data de/compression. 
| OFF | diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index a0a7baa0da45..a96719dc10fc 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_06_21 +VELOX_BRANCH=2024_06_26 VELOX_HOME="" #Set on run gluten on HDFS diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java index f454cf00c656..63fb0cc1b9bd 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornShuffleManager.java @@ -16,6 +16,7 @@ */ package org.apache.spark.shuffle.gluten.celeborn; +import org.apache.gluten.GlutenConfig; import org.apache.gluten.backendsapi.BackendsApiManager; import org.apache.gluten.exception.GlutenException; @@ -194,9 +195,14 @@ public ShuffleHandle registerShuffle( if (dependency instanceof ColumnarShuffleDependency) { if (fallbackPolicyRunner.applyAllFallbackPolicy( lifecycleManager, dependency.partitioner().numPartitions())) { - logger.warn("Fallback to ColumnarShuffleManager!"); - columnarShuffleIds.add(shuffleId); - return columnarShuffleManager().registerShuffle(shuffleId, dependency); + if (GlutenConfig.getConf().enableCelebornFallback()) { + logger.warn("Fallback to ColumnarShuffleManager!"); + columnarShuffleIds.add(shuffleId); + return columnarShuffleManager().registerShuffle(shuffleId, dependency); + } else { + throw new GlutenException( + "The Celeborn service(Master: " + celebornConf.masterHost() + ") is unavailable"); + } } else { return registerCelebornShuffleHandle(shuffleId, dependency); } @@ -217,7 +223,13 @@ public boolean unregisterShuffle(int shuffleId) { } } return CelebornUtils.unregisterShuffle( - lifecycleManager, shuffleClient, shuffleIdTracker, shuffleId, appUniqueId, isDriver()); + lifecycleManager, + shuffleClient, + shuffleIdTracker, + shuffleId, + appUniqueId, + throwsFetchFailure, + isDriver()); } @Override diff --git a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java index 4593d019c27e..6b4229ad3037 100644 --- a/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java +++ b/gluten-celeborn/common/src/main/java/org/apache/spark/shuffle/gluten/celeborn/CelebornUtils.java @@ -49,11 +49,21 @@ public static boolean unregisterShuffle( Object shuffleIdTracker, int appShuffleId, String appUniqueId, + boolean throwsFetchFailure, boolean isDriver) { try { - // for Celeborn 0.4.0 try { - if (lifecycleManager != null) { + try { + // for Celeborn 0.4.1 + if (lifecycleManager != null) { + Method unregisterAppShuffle = + lifecycleManager + .getClass() + .getMethod("unregisterAppShuffle", int.class, boolean.class); + unregisterAppShuffle.invoke(lifecycleManager, appShuffleId, throwsFetchFailure); + } + } catch (NoSuchMethodException ex) { + // for Celeborn 0.4.0 Method unregisterAppShuffle = lifecycleManager.getClass().getMethod("unregisterAppShuffle", int.class); unregisterAppShuffle.invoke(lifecycleManager, appShuffleId); @@ -65,7 +75,7 @@ public static boolean unregisterShuffle( unregisterAppShuffleId.invoke(shuffleIdTracker, 
shuffleClient, appShuffleId); } return true; - } catch (NoSuchMethodException ex) { + } catch (NoSuchMethodException | ClassNotFoundException ex) { try { if (lifecycleManager != null) { Method unregisterShuffleMethod = diff --git a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java index 2d6fc0748464..c3ece743310a 100644 --- a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java +++ b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java @@ -63,6 +63,6 @@ public static MemoryTarget newConsumer( factory = TreeMemoryConsumers.shared(); } - return dynamicOffHeapSizingIfEnabled(factory.newConsumer(tmm, name, spillers, virtualChildren)); + return factory.newConsumer(tmm, name, spillers, virtualChildren); } } diff --git a/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java b/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java index a7c12387a221..810c945d35ab 100644 --- a/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java +++ b/gluten-core/src/main/java/org/apache/gluten/vectorized/JniWorkspace.java @@ -75,14 +75,14 @@ private static JniWorkspace createDefault() { } } - public static void enableDebug() { + public static void enableDebug(String debugDir) { // Preserve the JNI libraries even after process exits. // This is useful for debugging native code if the debug symbols were embedded in // the libraries. synchronized (DEFAULT_INSTANCE_INIT_LOCK) { if (DEBUG_INSTANCE == null) { final File tempRoot = - Paths.get("/tmp").resolve("gluten-jni-debug-" + UUID.randomUUID()).toFile(); + Paths.get(debugDir).resolve("gluten-jni-debug-" + UUID.randomUUID()).toFile(); try { FileUtils.forceMkdir(tempRoot); } catch (IOException e) { diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 9a37c4a40dd1..3ca5e0313924 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -430,7 +430,9 @@ trait SparkPlanExecApi { * * @return */ - def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] + def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { + SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() + } def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala index 3bbd99c50a6a..9d231bbc2891 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/execution/BasicScanExecTransformer.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.hive.HiveTableScanExecTransformer import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} import com.google.protobuf.StringValue +import io.substrait.proto.NamedStruct import scala.collection.JavaConverters._ @@ -109,19 +110,19 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource } override protected def doTransform(context: SubstraitContext): TransformContext = { - val output = 
filteRedundantField(outputAttributes()) + val output = filterRedundantField(outputAttributes()) val typeNodes = ConverterUtils.collectAttributeTypeNodes(output) val nameList = ConverterUtils.collectAttributeNamesWithoutExprId(output) val columnTypeNodes = output.map { attr => if (getPartitionSchema.exists(_.name.equals(attr.name))) { - new ColumnTypeNode(1) + new ColumnTypeNode(NamedStruct.ColumnType.PARTITION_COL_VALUE) } else if (SparkShimLoader.getSparkShims.isRowIndexMetadataColumn(attr.name)) { - new ColumnTypeNode(3) + new ColumnTypeNode(NamedStruct.ColumnType.ROWINDEX_COL_VALUE) } else if (attr.isMetadataCol) { - new ColumnTypeNode(2) + new ColumnTypeNode(NamedStruct.ColumnType.METADATA_COL_VALUE) } else { - new ColumnTypeNode(0) + new ColumnTypeNode(NamedStruct.ColumnType.NORMAL_COL_VALUE) } }.asJava // Will put all filter expressions into an AND expression @@ -156,8 +157,8 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource TransformContext(output, output, readNode) } - def filteRedundantField(outputs: Seq[Attribute]): Seq[Attribute] = { - var final_output: List[Attribute] = List() + private def filterRedundantField(outputs: Seq[Attribute]): Seq[Attribute] = { + var finalOutput: List[Attribute] = List() val outputList = outputs.toArray for (i <- outputList.indices) { var dup = false @@ -167,9 +168,9 @@ trait BasicScanExecTransformer extends LeafTransformSupport with BaseDataSource } } if (!dup) { - final_output = final_output :+ outputList(i) + finalOutput = finalOutput :+ outputList(i) } } - final_output + finalOutput } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala index b7b0889dc1eb..d5222cfc6350 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.hive.HiveUDFTransformer import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ -trait Transformable extends Unevaluable { +trait Transformable { def getTransformer(childrenTransformers: Seq[ExpressionTransformer]): ExpressionTransformer } @@ -564,7 +564,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_ADD + ExpressionNames.CHECKED_ADD ) case tryEval @ TryEval(a: Subtract) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( @@ -572,7 +572,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_SUBTRACT + ExpressionNames.CHECKED_SUBTRACT ) case tryEval @ TryEval(a: Divide) => BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( @@ -580,7 +580,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_DIVIDE + ExpressionNames.CHECKED_DIVIDE ) case tryEval @ TryEval(a: Multiply) => 
BackendsApiManager.getSparkPlanExecApiInstance.genTryArithmeticTransformer( @@ -588,7 +588,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), tryEval, - ExpressionNames.CHECK_MULTIPLY + ExpressionNames.CHECKED_MULTIPLY ) case a: Add => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -596,7 +596,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_ADD + ExpressionNames.CHECKED_ADD ) case a: Subtract => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -604,7 +604,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_SUBTRACT + ExpressionNames.CHECKED_SUBTRACT ) case a: Multiply => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -612,7 +612,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_MULTIPLY + ExpressionNames.CHECKED_MULTIPLY ) case a: Divide => BackendsApiManager.getSparkPlanExecApiInstance.genArithmeticTransformer( @@ -620,7 +620,7 @@ object ExpressionConverter extends SQLConfHelper with Logging { replaceWithExpressionTransformerInternal(a.left, attributeSeq, expressionsMap), replaceWithExpressionTransformerInternal(a.right, attributeSeq, expressionsMap), a, - ExpressionNames.CHECK_DIVIDE + ExpressionNames.CHECKED_DIVIDE ) case tryEval: TryEval => // This is a placeholder to handle try_eval(other expressions). 
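The renames above standardize on the CHECKED_* names for both plain arithmetic and its TryEval-wrapped form; only the generator entry point differs (genTryArithmeticTransformer vs. genArithmeticTransformer), presumably so the TryEval path can surface overflow as null, matching Spark's try_add semantics. A hypothetical, simplified sketch of the name dispatch (the string values are assumptions for illustration, not Gluten's actual ExpressionNames constants):

```scala
// Hypothetical, simplified sketch of the arithmetic-to-name dispatch above.
// The string values are assumptions for illustration, not Gluten's actual
// ExpressionNames constants.
object CheckedArithmeticSketch {
  sealed trait Arith
  case object Add extends Arith
  case object Subtract extends Arith
  case object Multiply extends Arith
  case object Divide extends Arith

  // A plain operator and its TryEval-wrapped form resolve to the same
  // "checked" name; the TryEval path only changes how failure is surfaced
  // (a null result instead of an exception, as with Spark's try_add).
  def checkedName(op: Arith): String = op match {
    case Add      => "checked_add"
    case Subtract => "checked_subtract"
    case Multiply => "checked_multiply"
    case Divide   => "checked_divide"
  }
}
```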
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala index f0082456fb18..678ba38172eb 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionMappings.scala @@ -101,6 +101,7 @@ object ExpressionMappings { Sig[Encode](ENCODE), Sig[Uuid](UUID), Sig[BitLength](BIT_LENGTH), + Sig[Levenshtein](LEVENSHTEIN), Sig[UnBase64](UNBASE64), Sig[Base64](BASE64), diff --git a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala index 5d7209dfbfb4..e2b8439fd218 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/enumerated/RemoveFilter.scala @@ -42,6 +42,7 @@ object RemoveFilter extends RasRule[SparkPlan] { val filter = node.asInstanceOf[FilterExecTransformerBase] if (filter.isNoop()) { val out = NoopFilter(filter.child, filter.output) + out.copyTagsFrom(filter) return List(out) } List.empty diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala index e25f0a1f1c06..f66c5290e95f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadata.scala @@ -18,42 +18,18 @@ package org.apache.gluten.planner.metadata import org.apache.gluten.ras.Metadata -import org.apache.spark.sql.catalyst.expressions.Attribute - sealed trait GlutenMetadata extends Metadata { - import GlutenMetadata._ def schema(): Schema + def logicalLink(): LogicalLink } object GlutenMetadata { - def apply(schema: Schema): Metadata = { - Impl(schema) + def apply(schema: Schema, logicalLink: LogicalLink): Metadata = { + Impl(schema, logicalLink) } - private case class Impl(override val schema: Schema) extends GlutenMetadata - - case class Schema(output: Seq[Attribute]) { - private val hash = output.map(_.semanticHash()).hashCode() - - override def hashCode(): Int = { - hash - } - - override def equals(obj: Any): Boolean = obj match { - case other: Schema => - semanticEquals(other) - case _ => - false - } - - private def semanticEquals(other: Schema): Boolean = { - if (output.size != other.output.size) { - return false - } - output.zip(other.output).forall { - case (left, right) => - left.semanticEquals(right) - } - } + private case class Impl(override val schema: Schema, override val logicalLink: LogicalLink) + extends GlutenMetadata { + override def toString: String = s"$schema,$logicalLink" } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala index 6d1baa79db17..7b95f1383d04 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/GlutenMetadataModel.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.planner.metadata -import org.apache.gluten.planner.metadata.GlutenMetadata.Schema import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec import 
org.apache.gluten.ras.{Metadata, MetadataModel} @@ -31,18 +30,22 @@ object GlutenMetadataModel extends Logging { private object MetadataModelImpl extends MetadataModel[SparkPlan] { override def metadataOf(node: SparkPlan): Metadata = node match { case g: GroupLeafExec => throw new UnsupportedOperationException() - case other => GlutenMetadata(Schema(other.output)) + case other => + GlutenMetadata( + Schema(other.output), + other.logicalLink.map(LogicalLink(_)).getOrElse(LogicalLink.notFound)) } - override def dummy(): Metadata = GlutenMetadata(Schema(List())) + override def dummy(): Metadata = GlutenMetadata(Schema(List()), LogicalLink.notFound) override def verify(one: Metadata, other: Metadata): Unit = (one, other) match { - case (left: GlutenMetadata, right: GlutenMetadata) if left.schema() != right.schema() => - // We apply loose restriction on schema. Since Gluten still have some customized - // logics causing schema of an operator to change after being transformed. - // For example: https://github.com/apache/incubator-gluten/pull/5171 - logWarning(s"Warning: Schema mismatch: one: ${left.schema()}, other: ${right.schema()}") - case (left: GlutenMetadata, right: GlutenMetadata) if left == right => + case (left: GlutenMetadata, right: GlutenMetadata) => + implicitly[Verifier[Schema]].verify(left.schema(), right.schema()) + implicitly[Verifier[LogicalLink]].verify(left.logicalLink(), right.logicalLink()) case _ => throw new IllegalStateException(s"Metadata mismatch: one: $one, other $other") } } + + trait Verifier[T <: Any] { + def verify(one: T, other: T): Unit + } } diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/LogicalLink.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/LogicalLink.scala new file mode 100644 index 000000000000..4c3bffd471ad --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/LogicalLink.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.planner.metadata + +import org.apache.gluten.planner.metadata.GlutenMetadataModel.Verifier + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} + +case class LogicalLink(plan: LogicalPlan) { + override def hashCode(): Int = System.identityHashCode(plan) + override def equals(obj: Any): Boolean = obj match { + // LogicalLink's comparison is based on ref equality of the logical plans being compared. 
+ case LogicalLink(otherPlan) => plan eq otherPlan + case _ => false + } + + override def toString: String = s"${plan.nodeName}[${plan.stats.simpleString}]" +} + +object LogicalLink { + private case class LogicalLinkNotFound() extends logical.LeafNode { + override def output: Seq[Attribute] = List.empty + override def canEqual(that: Any): Boolean = throw new UnsupportedOperationException() + override def computeStats(): Statistics = Statistics(sizeInBytes = 0) + } + + val notFound = new LogicalLink(LogicalLinkNotFound()) + implicit val verifier: Verifier[LogicalLink] = new Verifier[LogicalLink] with Logging { + override def verify(one: LogicalLink, other: LogicalLink): Unit = { + // LogicalLink's comparison is based on ref equality of the logical plans being compared. + if (one != other) { + logWarning(s"Warning: Logical link mismatch: one: $one, other: $other") + } + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/Schema.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/Schema.scala new file mode 100644 index 000000000000..969d34d5cc82 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/planner/metadata/Schema.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.planner.metadata + +import org.apache.gluten.planner.metadata.GlutenMetadataModel.Verifier + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.Attribute + +case class Schema(output: Seq[Attribute]) { + private val hash = output.map(_.semanticHash()).hashCode() + + override def hashCode(): Int = { + hash + } + + override def equals(obj: Any): Boolean = obj match { + case other: Schema => + semanticEquals(other) + case _ => + false + } + + private def semanticEquals(other: Schema): Boolean = { + if (output.size != other.output.size) { + return false + } + output.zip(other.output).forall { + case (left, right) => + left.semanticEquals(right) + } + } + + override def toString: String = { + output.toString() + } +} + +object Schema { + implicit val verifier: Verifier[Schema] = new Verifier[Schema] with Logging { + override def verify(one: Schema, other: Schema): Unit = { + if (one != other) { + // We apply loose restriction on schema. Since Gluten still have some customized + // logics causing schema of an operator to change after being transformed. 
+ // For example: https://github.com/apache/incubator-gluten/pull/5171 + logWarning(s"Warning: Schema mismatch: one: $one, other: $other") + } + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala index 475f6292094c..18db0f959491 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/planner/property/Conv.scala @@ -99,6 +99,7 @@ case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] { } val transition = Conv.findTransition(conv, reqConv) val after = transition.apply(node) + after.copyTagsFrom(node) List(after) } diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala deleted file mode 100644 index 1e3681355d6c..000000000000 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/Iterators.scala +++ /dev/null @@ -1,228 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
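The Verifier trait introduced in GlutenMetadataModel, together with the implicit instances defined on Schema and LogicalLink above, is a small typeclass: each metadata component owns its own notion of "close enough". A minimal standalone sketch of the pattern, with a toy Schema type and println standing in for Spark's Logging (names simplified, not the actual Gluten classes):

```scala
// Minimal sketch of the typeclass-style verification introduced above.
// The toy Schema type and println logging are assumptions; the real code
// verifies Schema and LogicalLink and logs through Spark's Logging trait.
object VerifierSketch {
  trait Verifier[T] {
    def verify(one: T, other: T): Unit
  }

  final case class Schema(columns: Seq[String])

  object Schema {
    // Loose check: a mismatch is only a warning, since some Gluten rules can
    // legitimately change an operator's schema after transformation.
    implicit val verifier: Verifier[Schema] = new Verifier[Schema] {
      override def verify(one: Schema, other: Schema): Unit =
        if (one != other) println(s"Warning: Schema mismatch: $one, $other")
    }
  }

  // Mirrors GlutenMetadataModel.verify: resolve the Verifier implicitly and
  // delegate the comparison to it.
  def verify[T](one: T, other: T)(implicit v: Verifier[T]): Unit =
    v.verify(one, other)
}
```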
- */ -package org.apache.gluten.utils - -import org.apache.spark.{InterruptibleIterator, TaskContext} -import org.apache.spark.util.TaskResources - -import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicBoolean - -private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit) extends Iterator[A] { - private var closer: Option[() => Unit] = None - - TaskResources.addRecycler("Iterators#PayloadCloser", 100) { - tryClose() - } - - override def hasNext: Boolean = { - tryClose() - in.hasNext - } - - override def next(): A = { - val a: A = in.next() - closer.synchronized { - closer = Some( - () => { - closeCallback.apply(a) - }) - } - a - } - - private def tryClose(): Unit = { - closer.synchronized { - closer match { - case Some(c) => c.apply() - case None => - } - closer = None // make sure the payload is closed once - } - } -} - -private class IteratorCompleter[A](in: Iterator[A])(completionCallback: => Unit) - extends Iterator[A] { - private val completed = new AtomicBoolean(false) - - TaskResources.addRecycler("Iterators#IteratorRecycler", 100) { - tryComplete() - } - - override def hasNext: Boolean = { - val out = in.hasNext - if (!out) { - tryComplete() - } - out - } - - override def next(): A = { - in.next() - } - - private def tryComplete(): Unit = { - if (!completed.compareAndSet(false, true)) { - return // make sure the iterator is completed once - } - completionCallback - } -} - -private class LifeTimeAccumulator[A](in: Iterator[A], onCollected: Long => Unit) - extends Iterator[A] { - private val closed = new AtomicBoolean(false) - private val startTime = System.nanoTime() - - TaskResources.addRecycler("Iterators#LifeTimeAccumulator", 100) { - tryFinish() - } - - override def hasNext: Boolean = { - val out = in.hasNext - if (!out) { - tryFinish() - } - out - } - - override def next(): A = { - in.next() - } - - private def tryFinish(): Unit = { - // pipeline metric should only be calculate once. - if (!closed.compareAndSet(false, true)) { - return - } - val lifeTime = TimeUnit.NANOSECONDS.toMillis( - System.nanoTime() - startTime - ) - onCollected(lifeTime) - } -} - -private class ReadTimeAccumulator[A](in: Iterator[A], onAdded: Long => Unit) extends Iterator[A] { - - override def hasNext: Boolean = { - val prev = System.nanoTime() - val out = in.hasNext - val after = System.nanoTime() - val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) - onAdded(duration) - out - } - - override def next(): A = { - val prev = System.nanoTime() - val out = in.next() - val after = System.nanoTime() - val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) - onAdded(duration) - out - } -} - -/** - * To protect the wrapped iterator to avoid undesired order of calls to its `hasNext` and `next` - * methods. 
- */ -private class InvocationFlowProtection[A](in: Iterator[A]) extends Iterator[A] { - sealed private trait State - private case object Init extends State - private case class HasNextCalled(hasNext: Boolean) extends State - private case object NextCalled extends State - - private var state: State = Init - - override def hasNext: Boolean = { - val out = state match { - case Init | NextCalled => - in.hasNext - case HasNextCalled(lastHasNext) => - lastHasNext - } - state = HasNextCalled(out) - out - } - - override def next(): A = { - val out = state match { - case Init | NextCalled => - if (!in.hasNext) { - throw new IllegalStateException("End of stream") - } - in.next() - case HasNextCalled(lastHasNext) => - if (!lastHasNext) { - throw new IllegalStateException("End of stream") - } - in.next() - } - state = NextCalled - out - } -} - -class WrapperBuilder[A](in: Iterator[A]) { // FIXME how to make the ctor companion-private? - private var wrapped: Iterator[A] = in - - def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] = { - wrapped = new PayloadCloser(wrapped)(closeCallback) - this - } - - def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] = { - wrapped = new IteratorCompleter(wrapped)(completionCallback) - this - } - - def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] = { - wrapped = new LifeTimeAccumulator[A](wrapped, onCollected) - this - } - - def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] = { - wrapped = new ReadTimeAccumulator[A](wrapped, onAdded) - this - } - - def asInterruptible(context: TaskContext): WrapperBuilder[A] = { - wrapped = new InterruptibleIterator[A](context, wrapped) - this - } - - def protectInvocationFlow(): WrapperBuilder[A] = { - wrapped = new InvocationFlowProtection[A](wrapped) - this - } - - def create(): Iterator[A] = { - wrapped - } -} - -/** - * Utility class to provide iterator wrappers for non-trivial use cases. E.g. iterators that manage - * payload's lifecycle. - */ -object Iterators { - def wrap[A](in: Iterator[A]): WrapperBuilder[A] = { - new WrapperBuilder[A](in) - } -} diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala index 77d5d55f618d..a6ec7cb21fbf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala @@ -24,37 +24,34 @@ import io.substrait.proto.{NamedStruct, Plan} object SubstraitPlanPrinterUtil extends Logging { - /** Transform Substrait Plan to json format. 
*/ - def substraitPlanToJson(substraintPlan: Plan): String = { + private def typeRegistry( + d: com.google.protobuf.Descriptors.Descriptor): com.google.protobuf.TypeRegistry = { val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry + com.google.protobuf.TypeRegistry .newBuilder() - .add(substraintPlan.getDescriptorForType()) + .add(d) .add(defaultRegistry) .build() - JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan) + } + private def MessageToJson(message: com.google.protobuf.Message): String = { + val registry = typeRegistry(message.getDescriptorForType) + JsonFormat.printer.usingTypeRegistry(registry).print(message) } - def substraitNamedStructToJson(substraintPlan: NamedStruct): String = { - val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry - .newBuilder() - .add(substraintPlan.getDescriptorForType()) - .add(defaultRegistry) - .build() - JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan) + /** Transform Substrait Plan to json format. */ + def substraitPlanToJson(substraitPlan: Plan): String = { + MessageToJson(substraitPlan) + } + + def substraitNamedStructToJson(namedStruct: NamedStruct): String = { + MessageToJson(namedStruct) } /** Transform substrait plan json string to PlanNode */ def jsonToSubstraitPlan(planJson: String): Plan = { try { val builder = Plan.newBuilder() - val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry - .newBuilder() - .add(builder.getDescriptorForType) - .add(defaultRegistry) - .build() + val registry = typeRegistry(builder.getDescriptorForType) JsonFormat.parser().usingTypeRegistry(registry).merge(planJson, builder) builder.build() } catch { diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala new file mode 100644 index 000000000000..eedfa66cfeaf --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/Iterators.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.utils.iterator + +import org.apache.gluten.utils.iterator.IteratorsV1.WrapperBuilderV1 + +import org.apache.spark.TaskContext + +/** + * Utility class to provide iterator wrappers for non-trivial use cases. E.g. iterators that manage + * payload's lifecycle. 
+ */ +object Iterators { + sealed trait Version + case object V1 extends Version + + private val DEFAULT_VERSION: Version = V1 + + trait WrapperBuilder[A] { + def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] + def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] + def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] + def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] + def asInterruptible(context: TaskContext): WrapperBuilder[A] + def protectInvocationFlow(): WrapperBuilder[A] + def create(): Iterator[A] + } + + def wrap[A](in: Iterator[A]): WrapperBuilder[A] = { + wrap(V1, in) + } + + def wrap[A](version: Version, in: Iterator[A]): WrapperBuilder[A] = { + version match { + case V1 => + new WrapperBuilderV1[A](in) + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala new file mode 100644 index 000000000000..3e9248c44458 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/iterator/IteratorsV1.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
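The Iterators facade above keeps `wrap` as the single public entry point and hides the concrete wrappers behind a `Version` tag (only V1 exists so far), so call sites chain decorators through `WrapperBuilder` and finish with `create()`. A usage sketch modelled on IteratorSuite, with placeholder payloads and no-op callbacks; wrappers that register recyclers need a task-resource scope, hence TaskResources.runUnsafe:

```scala
// Usage sketch for the versioned Iterators facade, modelled on IteratorSuite.
// The payload values and callback bodies are placeholders; wrappers that
// register recyclers must run inside a task-resource scope, hence runUnsafe.
import org.apache.gluten.utils.iterator.Iterators
import org.apache.spark.util.TaskResources

object IteratorsUsageSketch {
  def drain(): Unit = TaskResources.runUnsafe {
    val source = Iterator("one", "two", "three")
    val wrapped = Iterators
      .wrap(Iterators.V1, source)   // pick the implementation explicitly
      .recyclePayload(_ => ())      // close each payload once it is consumed
      .recycleIterator {}           // run once when the iterator is drained
      .collectReadMillis(_ => ())   // accumulate time spent in hasNext/next
      .protectInvocationFlow()      // tolerate out-of-order hasNext/next calls
      .create()
    wrapped.foreach(_ => ())        // consuming the iterator triggers the recyclers
  }
}
```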
+ */ +package org.apache.gluten.utils.iterator + +import org.apache.gluten.utils.iterator.Iterators.WrapperBuilder + +import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.util.TaskResources + +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +object IteratorsV1 { + private class PayloadCloser[A](in: Iterator[A])(closeCallback: A => Unit) extends Iterator[A] { + private var closer: Option[() => Unit] = None + + TaskResources.addRecycler("Iterators#PayloadCloser", 100) { + tryClose() + } + + override def hasNext: Boolean = { + tryClose() + in.hasNext + } + + override def next(): A = { + val a: A = in.next() + closer.synchronized { + closer = Some( + () => { + closeCallback.apply(a) + }) + } + a + } + + private def tryClose(): Unit = { + closer.synchronized { + closer match { + case Some(c) => c.apply() + case None => + } + closer = None // make sure the payload is closed once + } + } + } + + private class IteratorCompleter[A](in: Iterator[A])(completionCallback: => Unit) + extends Iterator[A] { + private val completed = new AtomicBoolean(false) + + TaskResources.addRecycler("Iterators#IteratorRecycler", 100) { + tryComplete() + } + + override def hasNext: Boolean = { + val out = in.hasNext + if (!out) { + tryComplete() + } + out + } + + override def next(): A = { + in.next() + } + + private def tryComplete(): Unit = { + if (!completed.compareAndSet(false, true)) { + return // make sure the iterator is completed once + } + completionCallback + } + } + + private class LifeTimeAccumulator[A](in: Iterator[A], onCollected: Long => Unit) + extends Iterator[A] { + private val closed = new AtomicBoolean(false) + private val startTime = System.nanoTime() + + TaskResources.addRecycler("Iterators#LifeTimeAccumulator", 100) { + tryFinish() + } + + override def hasNext: Boolean = { + val out = in.hasNext + if (!out) { + tryFinish() + } + out + } + + override def next(): A = { + in.next() + } + + private def tryFinish(): Unit = { + // pipeline metric should only be calculate once. + if (!closed.compareAndSet(false, true)) { + return + } + val lifeTime = TimeUnit.NANOSECONDS.toMillis( + System.nanoTime() - startTime + ) + onCollected(lifeTime) + } + } + + private class ReadTimeAccumulator[A](in: Iterator[A], onAdded: Long => Unit) extends Iterator[A] { + + override def hasNext: Boolean = { + val prev = System.nanoTime() + val out = in.hasNext + val after = System.nanoTime() + val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) + onAdded(duration) + out + } + + override def next(): A = { + val prev = System.nanoTime() + val out = in.next() + val after = System.nanoTime() + val duration = TimeUnit.NANOSECONDS.toMillis(after - prev) + onAdded(duration) + out + } + } + + /** + * To protect the wrapped iterator to avoid undesired order of calls to its `hasNext` and `next` + * methods. 
+ */ + private class InvocationFlowProtection[A](in: Iterator[A]) extends Iterator[A] { + sealed private trait State + private case object Init extends State + private case class HasNextCalled(hasNext: Boolean) extends State + private case object NextCalled extends State + + private var state: State = Init + + override def hasNext: Boolean = { + val out = state match { + case Init | NextCalled => + in.hasNext + case HasNextCalled(lastHasNext) => + lastHasNext + } + state = HasNextCalled(out) + out + } + + override def next(): A = { + val out = state match { + case Init | NextCalled => + if (!in.hasNext) { + throw new IllegalStateException("End of stream") + } + in.next() + case HasNextCalled(lastHasNext) => + if (!lastHasNext) { + throw new IllegalStateException("End of stream") + } + in.next() + } + state = NextCalled + out + } + } + + class WrapperBuilderV1[A] private[iterator] (in: Iterator[A]) extends WrapperBuilder[A] { + private var wrapped: Iterator[A] = in + + override def recyclePayload(closeCallback: (A) => Unit): WrapperBuilder[A] = { + wrapped = new PayloadCloser(wrapped)(closeCallback) + this + } + + override def recycleIterator(completionCallback: => Unit): WrapperBuilder[A] = { + wrapped = new IteratorCompleter(wrapped)(completionCallback) + this + } + + override def collectLifeMillis(onCollected: Long => Unit): WrapperBuilder[A] = { + wrapped = new LifeTimeAccumulator[A](wrapped, onCollected) + this + } + + override def collectReadMillis(onAdded: Long => Unit): WrapperBuilder[A] = { + wrapped = new ReadTimeAccumulator[A](wrapped, onAdded) + this + } + + override def asInterruptible(context: TaskContext): WrapperBuilder[A] = { + wrapped = new InterruptibleIterator[A](context, wrapped) + this + } + + override def protectInvocationFlow(): WrapperBuilder[A] = { + wrapped = new InvocationFlowProtection[A](wrapped) + this + } + + override def create(): Iterator[A] = { + wrapped + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/HdfsConfGenerator.scala b/gluten-core/src/main/scala/org/apache/spark/HdfsConfGenerator.scala index 9756837d96e5..04272517e5bb 100644 --- a/gluten-core/src/main/scala/org/apache/spark/HdfsConfGenerator.scala +++ b/gluten-core/src/main/scala/org/apache/spark/HdfsConfGenerator.scala @@ -41,8 +41,8 @@ object HdfsConfGenerator extends Logging { addFileMethod.invoke(sc, path, Boolean.box(false), Boolean.box(true), Boolean.box(false)) // Overwrite the spark internal config `spark.app.initial.file.urls`, // so that the file can be available before initializing executor plugin. 
- assert(sc.addedFiles.nonEmpty) - sc.conf.set("spark.app.initial.file.urls", sc.addedFiles.keys.toSeq.mkString(",")) + assert(sc.listFiles.nonEmpty) + sc.conf.set("spark.app.initial.file.urls", sc.listFiles().mkString(",")) } private def ignoreKey(key: String): Boolean = { diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala index f9ad5201d8db..7063c3f67b80 100644 --- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala +++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/datasources/GlutenWriterColumnarRules.scala @@ -162,19 +162,28 @@ object GlutenWriterColumnarRules { if write.getClass.getName == NOOP_WRITE && BackendsApiManager.getSettings.enableNativeWriteFiles() => injectFakeRowAdaptor(rc, rc.child) - case rc @ DataWritingCommandExec(cmd, child) - if BackendsApiManager.getSettings.supportNativeWrite(child.output.toStructType.fields) => - val format = getNativeFormat(cmd) - session.sparkContext.setLocalProperty( - "staticPartitionWriteOnly", - BackendsApiManager.getSettings.staticPartitionWriteOnly().toString) - // FIXME: We should only use context property if having no other approaches. - // Should see if there is another way to pass these options. - session.sparkContext.setLocalProperty("isNativeAppliable", format.isDefined.toString) - session.sparkContext.setLocalProperty("nativeFormat", format.getOrElse("")) - if (format.isDefined) { - injectFakeRowAdaptor(rc, child) + case rc @ DataWritingCommandExec(cmd, child) => + if (BackendsApiManager.getSettings.supportNativeWrite(child.output.toStructType.fields)) { + val format = getNativeFormat(cmd) + session.sparkContext.setLocalProperty( + "staticPartitionWriteOnly", + BackendsApiManager.getSettings.staticPartitionWriteOnly().toString) + // FIXME: We should only use context property if having no other approaches. + // Should see if there is another way to pass these options. + session.sparkContext.setLocalProperty("isNativeAppliable", format.isDefined.toString) + session.sparkContext.setLocalProperty("nativeFormat", format.getOrElse("")) + if (format.isDefined) { + injectFakeRowAdaptor(rc, child) + } else { + rc.withNewChildren(rc.children.map(apply)) + } } else { + session.sparkContext.setLocalProperty( + "staticPartitionWriteOnly", + BackendsApiManager.getSettings.staticPartitionWriteOnly().toString) + session.sparkContext.setLocalProperty("isNativeAppliable", "false") + session.sparkContext.setLocalProperty("nativeFormat", "") + rc.withNewChildren(rc.children.map(apply)) } case plan: SparkPlan => plan.withNewChildren(plan.children.map(apply)) diff --git a/gluten-core/src/test/scala/org/apache/gluten/utils/IteratorSuite.scala b/gluten-core/src/test/scala/org/apache/gluten/utils/iterator/IteratorSuite.scala similarity index 86% rename from gluten-core/src/test/scala/org/apache/gluten/utils/IteratorSuite.scala rename to gluten-core/src/test/scala/org/apache/gluten/utils/iterator/IteratorSuite.scala index 389e2adfefd4..1a84d671922d 100644 --- a/gluten-core/src/test/scala/org/apache/gluten/utils/IteratorSuite.scala +++ b/gluten-core/src/test/scala/org/apache/gluten/utils/iterator/IteratorSuite.scala @@ -14,18 +14,25 @@ * See the License for the specific language governing permissions and * limitations under the License. 
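The restructured rule in GlutenWriterColumnarRules now sets the write-path flags on every DataWritingCommandExec branch, including the non-native fallback, presumably so executors do not observe stale values from an earlier plan. The flags ride on Spark local properties, which are captured at job submission and visible to running tasks; a hedged sketch of that mechanism (the property names come from the rule above, everything else is illustrative):

```scala
// Sketch of passing per-query flags through Spark local properties, as the
// rule above does for the native write path. The property names come from
// the rule; the SparkContext wiring is illustrative.
import org.apache.spark.{SparkContext, TaskContext}

object LocalPropertySketch {
  def markNonNative(sc: SparkContext): Unit = {
    // Set on the driver before the job is submitted ...
    sc.setLocalProperty("isNativeAppliable", "false")
    sc.setLocalProperty("nativeFormat", "")
  }

  def isNativeWrite: Boolean = {
    // ... and read back inside a running task on the executor side.
    val ctx = TaskContext.get()
    ctx != null && "true" == ctx.getLocalProperty("isNativeAppliable")
  }
}
```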
*/ -package org.apache.gluten.utils +package org.apache.gluten.utils.iterator + +import org.apache.gluten.utils.iterator.Iterators.{V1, WrapperBuilder} import org.apache.spark.util.TaskResources import org.scalatest.funsuite.AnyFunSuite -class IteratorSuite extends AnyFunSuite { +class IteratorV1Suite extends IteratorSuite { + override protected def wrap[A](in: Iterator[A]): WrapperBuilder[A] = Iterators.wrap(V1, in) +} + +abstract class IteratorSuite extends AnyFunSuite { + protected def wrap[A](in: Iterator[A]): WrapperBuilder[A] + test("Trivial wrapping") { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .create() assertResult(strings) { wrapped.toArray @@ -37,8 +44,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .recycleIterator { completeCount += 1 } @@ -56,8 +62,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val _ = Iterators - .wrap(itr) + val _ = wrap(itr) .recycleIterator { completeCount += 1 } @@ -72,8 +77,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .recyclePayload { _: String => closeCount += 1 } .create() assertResult(strings) { @@ -89,8 +93,7 @@ class IteratorSuite extends AnyFunSuite { TaskResources.runUnsafe { val strings = Array[String]("one", "two", "three") val itr = strings.toIterator - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .recyclePayload { _: String => closeCount += 1 } .create() assertResult(strings.take(2)) { @@ -115,8 +118,7 @@ class IteratorSuite extends AnyFunSuite { new Object } } - val wrapped = Iterators - .wrap(itr) + val wrapped = wrap(itr) .protectInvocationFlow() .create() wrapped.hasNext diff --git a/gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala b/gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala new file mode 100644 index 000000000000..aa69f309aac8 --- /dev/null +++ b/gluten-core/src/test/scala/org/apache/spark/utils/iterator/IteratorBenchmark.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
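Renaming the suite and splitting it into an abstract IteratorSuite plus a concrete IteratorV1Suite is the usual ScalaTest way to run one body of tests against several implementations: the shared tests only call a protected factory, and each subclass binds it to one version. A minimal sketch of the shape (toy types and assertions only, not the Gluten classes):

```scala
// Minimal sketch of the abstract-suite pattern used above: the shared tests
// only call a protected factory, and each concrete suite binds it to one
// implementation. Toy types and assertions only.
import org.scalatest.funsuite.AnyFunSuite

abstract class WrappingSuiteSketch extends AnyFunSuite {
  // Implementations under test provide the wrapping strategy.
  protected def wrap[A](in: Iterator[A]): Iterator[A]

  test("wrapping preserves the elements") {
    assertResult(Seq(1, 2, 3))(wrap(Iterator(1, 2, 3)).toSeq)
  }
}

// One concrete suite per implementation; a future V2 would only need another
// one-line subclass.
class IdentityWrappingSuiteSketch extends WrappingSuiteSketch {
  override protected def wrap[A](in: Iterator[A]): Iterator[A] = in
}
```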
+ */ +package org.apache.spark.utils.iterator + +import org.apache.gluten.utils.iterator.Iterators +import org.apache.gluten.utils.iterator.Iterators.V1 + +import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.util.TaskResources + +object IteratorBenchmark extends BenchmarkBase { + + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { + runBenchmark("Iterator Nesting") { + TaskResources.runUnsafe { + val nPayloads: Int = 50000000 // 50 millions + + def makeScalaIterator: Iterator[Any] = { + (0 until nPayloads).view.map { _: Int => new Object }.iterator + } + + def compareIterator(name: String)( + makeGlutenIterator: Iterators.Version => Iterator[Any]): Unit = { + val benchmark = new Benchmark(name, nPayloads, output = output) + benchmark.addCase("Scala Iterator") { + _ => + val count = makeScalaIterator.count(_ => true) + assert(count == nPayloads) + } + benchmark.addCase("Gluten Iterator V1") { + _ => + val count = makeGlutenIterator(V1).count(_ => true) + assert(count == nPayloads) + } + benchmark.run() + } + + compareIterator("0 Levels Nesting") { + version => + Iterators + .wrap(version, makeScalaIterator) + .create() + } + compareIterator("1 Levels Nesting - read") { + version => + Iterators + .wrap(version, makeScalaIterator) + .collectReadMillis { _ => } + .create() + } + compareIterator("5 Levels Nesting - read") { + version => + Iterators + .wrap(version, makeScalaIterator) + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .create() + } + compareIterator("10 Levels Nesting - read") { + version => + Iterators + .wrap(version, makeScalaIterator) + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .collectReadMillis { _ => } + .create() + } + compareIterator("1 Levels Nesting - recycle") { + version => + Iterators + .wrap(version, makeScalaIterator) + .recycleIterator {} + .create() + } + compareIterator("5 Levels Nesting - recycle") { + version => + Iterators + .wrap(version, makeScalaIterator) + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .create() + } + compareIterator("10 Levels Nesting - recycle") { + version => + Iterators + .wrap(version, makeScalaIterator) + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .recycleIterator {} + .create() + } + } + } + } +} diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java b/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java index efee20e48b83..51f49da704eb 100644 --- a/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java +++ b/gluten-data/src/main/java/org/apache/gluten/memory/arrow/alloc/ArrowBufferAllocators.java @@ -60,11 +60,12 @@ public static class ArrowBufferAllocatorManager implements TaskResource { listener = new ManagedAllocationListener( MemoryTargets.throwOnOom( - MemoryTargets.newConsumer( - tmm, - "ArrowContextInstance", - Collections.emptyList(), - Collections.emptyMap())), + MemoryTargets.dynamicOffHeapSizingIfEnabled( + 
MemoryTargets.newConsumer( + tmm, + "ArrowContextInstance", + Collections.emptyList(), + Collections.emptyMap()))), TaskResources.getSharedUsage()); } diff --git a/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java b/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java index 928f869ba4e1..37456badd42f 100644 --- a/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java +++ b/gluten-data/src/main/java/org/apache/gluten/memory/nmm/NativeMemoryManagers.java @@ -26,6 +26,8 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.util.TaskResources; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Arrays; import java.util.Collections; @@ -37,6 +39,7 @@ import java.util.stream.Stream; public final class NativeMemoryManagers { + private static final Logger LOG = LoggerFactory.getLogger(NativeMemoryManagers.class); // TODO: Let all caller support spill. public static NativeMemoryManager contextInstance(String name) { @@ -67,86 +70,92 @@ private static NativeMemoryManager createNativeMemoryManager( final MemoryTarget target = MemoryTargets.throwOnOom( MemoryTargets.overAcquire( - MemoryTargets.newConsumer( - tmm, - name, - // call memory manager's shrink API, if no good then call the spiller - Stream.concat( - Stream.of( - new Spiller() { - @Override - public long spill(MemoryTarget self, long size) { - return Optional.of(out.get()) - .map(nmm -> nmm.shrink(size)) - .orElseThrow( - () -> - new IllegalStateException( - "" - + "Shrink is requested before native " - + "memory manager is created. Try moving " - + "any actions about memory allocation out " - + "from the memory manager constructor.")); - } + MemoryTargets.dynamicOffHeapSizingIfEnabled( + MemoryTargets.newConsumer( + tmm, + name, + // call memory manager's shrink API, if no good then call the spiller + Stream.concat( + Stream.of( + new Spiller() { + @Override + public long spill(MemoryTarget self, long size) { + return Optional.ofNullable(out.get()) + .map(nmm -> nmm.shrink(size)) + .orElseGet( + () -> { + LOG.warn( + "Shrink is requested before native " + + "memory manager is created. 
Try moving " + + "any actions about memory allocation" + + " out from the memory manager" + + " constructor."); + return 0L; + }); + } - @Override - public Set applicablePhases() { - return Spillers.PHASE_SET_SHRINK_ONLY; - } - }), - spillers.stream()) - .map(spiller -> Spillers.withMinSpillSize(spiller, reservationBlockSize)) - .collect(Collectors.toList()), - Collections.singletonMap( - "single", - new MemoryUsageRecorder() { - @Override - public void inc(long bytes) { - // no-op - } + @Override + public Set applicablePhases() { + return Spillers.PHASE_SET_SHRINK_ONLY; + } + }), + spillers.stream()) + .map( + spiller -> Spillers.withMinSpillSize(spiller, reservationBlockSize)) + .collect(Collectors.toList()), + Collections.singletonMap( + "single", + new MemoryUsageRecorder() { + @Override + public void inc(long bytes) { + // no-op + } - @Override - public long peak() { - throw new UnsupportedOperationException("Not implemented"); - } + @Override + public long peak() { + throw new UnsupportedOperationException("Not implemented"); + } - @Override - public long current() { - throw new UnsupportedOperationException("Not implemented"); - } + @Override + public long current() { + throw new UnsupportedOperationException("Not implemented"); + } - @Override - public MemoryUsageStats toStats() { - return getNativeMemoryManager().collectMemoryUsage(); - } + @Override + public MemoryUsageStats toStats() { + return getNativeMemoryManager().collectMemoryUsage(); + } - private NativeMemoryManager getNativeMemoryManager() { - return Optional.of(out.get()) - .orElseThrow( - () -> - new IllegalStateException( - "" - + "Memory usage stats are requested before native " - + "memory manager is created. Try moving any " - + "actions about memory allocation out from the " - + "memory manager constructor.")); - } - })), - MemoryTargets.newConsumer( - tmm, - "OverAcquire.DummyTarget", - Collections.singletonList( - new Spiller() { - @Override - public long spill(MemoryTarget self, long size) { - return self.repay(size); - } + private NativeMemoryManager getNativeMemoryManager() { + return Optional.ofNullable(out.get()) + .orElseThrow( + () -> + new IllegalStateException( + "" + + "Memory usage stats are requested before" + + " native memory manager is created. 
Try" + + " moving any actions about memory" + + " allocation out from the memory manager" + + " constructor.")); + } + }))), + MemoryTargets.dynamicOffHeapSizingIfEnabled( + MemoryTargets.newConsumer( + tmm, + "OverAcquire.DummyTarget", + Collections.singletonList( + new Spiller() { + @Override + public long spill(MemoryTarget self, long size) { + return self.repay(size); + } - @Override - public Set applicablePhases() { - return Spillers.PHASE_SET_ALL; - } - }), - Collections.emptyMap()), + @Override + public Set applicablePhases() { + return Spillers.PHASE_SET_ALL; + } + }), + Collections.emptyMap())), overAcquiredRatio)); // listener ManagedReservationListener rl = diff --git a/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java b/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java index e54724a599c1..2ac048b2b960 100644 --- a/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java +++ b/gluten-data/src/main/java/org/apache/gluten/vectorized/NativePlanEvaluator.java @@ -71,7 +71,7 @@ public GeneralOutIterator createKernelWithBatchIterator( @Override public long spill(MemoryTarget self, long size) { ColumnarBatchOutIterator instance = - Optional.of(outIterator.get()) + Optional.ofNullable(outIterator.get()) .orElseThrow( () -> new IllegalStateException( diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala index 9d9f5ab1765c..840f8618b0b4 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/ColumnarBuildSideRelation.scala @@ -21,7 +21,8 @@ import org.apache.gluten.exec.Runtimes import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers import org.apache.gluten.sql.shims.SparkShimLoader -import org.apache.gluten.utils.{ArrowAbiUtil, Iterators} +import org.apache.gluten.utils.ArrowAbiUtil +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.{ColumnarBatchSerializerJniWrapper, NativeColumnarToRowJniWrapper} import org.apache.spark.sql.catalyst.InternalRow diff --git a/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala b/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala index 083915f12db9..090b8fa2562a 100644 --- a/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala +++ b/gluten-data/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.utils import org.apache.gluten.columnarbatch.ColumnarBatches import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators import org.apache.gluten.memory.nmm.NativeMemoryManagers -import org.apache.gluten.utils.Iterators +import org.apache.gluten.utils.iterator.Iterators import org.apache.gluten.vectorized.{ArrowWritableColumnVector, NativeColumnarToRowInfo, NativeColumnarToRowJniWrapper, NativePartitioning} import org.apache.spark.{Partitioner, RangePartitioner, ShuffleDependency} diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala index effebd41bb3b..1128ab8dec01 100644 --- 
a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala @@ -82,15 +82,23 @@ private object GroupBasedBestFinder { return Some(KnownCostPath(ras, path)) } val childrenGroups = can.getChildrenGroups(allGroups).map(gn => allGroups(gn.groupId())) - val maybeBestChildrenPaths: Seq[Option[RasPath[T]]] = childrenGroups.map { - childGroup => childrenGroupsOutput(childGroup).map(kcg => kcg.best().rasPath) + val maybeBestChildrenPaths: Seq[Option[KnownCostPath[T]]] = childrenGroups.map { + childGroup => childrenGroupsOutput(childGroup).map(kcg => kcg.best()) } if (maybeBestChildrenPaths.exists(_.isEmpty)) { // Node should only be solved when all children outputs exist. return None } val bestChildrenPaths = maybeBestChildrenPaths.map(_.get) - Some(KnownCostPath(ras, path.RasPath(ras, can, bestChildrenPaths).get)) + val kcp = KnownCostPath(ras, path.RasPath(ras, can, bestChildrenPaths.map(_.rasPath)).get) + // Cost should be in monotonically increasing basis. + bestChildrenPaths.map(_.cost).foreach { + childCost => + assert( + ras.costModel.costComparator().gteq(kcp.cost, childCost), + "Illegal decreasing cost") + } + Some(kcp) } override def solveGroup( diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala index 60ec2eedd410..e1ccfa1f44aa 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala @@ -230,7 +230,7 @@ class OperationSuite extends AnyFunSuite { 48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Leaf(30)))))))))))) assert(costModel.costOfCount == 32) // TODO reduce this for performance - assert(costModel.costCompareCount == 20) // TODO reduce this for performance + assert(costModel.costCompareCount == 50) // TODO reduce this for performance } test("Cost evaluation count - max cost") { @@ -292,7 +292,7 @@ class OperationSuite extends AnyFunSuite { 48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Leaf(30)))))))))))) assert(costModel.costOfCount == 32) // TODO reduce this for performance - assert(costModel.costCompareCount == 20) // TODO reduce this for performance + assert(costModel.costCompareCount == 50) // TODO reduce this for performance } } diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 19c9b2cf478f..d12a40b764f8 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -172,6 +172,7 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("shuffle function - array for primitive type not containing null") .exclude("shuffle function - array for primitive type containing null") .exclude("shuffle function - array for non-primitive type") + .exclude("flatten function") enableSuite[GlutenDataFrameHintSuite] enableSuite[GlutenDataFrameImplicitsSuite] enableSuite[GlutenDataFrameJoinSuite].exclude( @@ -436,7 +437,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string regex_replace / regex_extract") .exclude("string overlay function") .exclude("binary 
overlay function") - .exclude("string / binary substring function") .exclude("string parse_url function") enableSuite[GlutenSubquerySuite] .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery") @@ -674,7 +674,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("Sequence with default step") .exclude("Reverse") .exclude("elementAt") - .exclude("Flatten") .exclude("ArrayRepeat") .exclude("Array remove") .exclude("Array Distinct") @@ -894,7 +893,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-34814: LikeSimplification should handle NULL") enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix") enableSuite[GlutenStringExpressionsSuite] - .exclude("concat") .exclude("StringComparison") .exclude("Substring") .exclude("string substring_index function") @@ -902,23 +900,15 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("string for ascii") .exclude("base64/unbase64 for string") .exclude("encode/decode for string") - .exclude("Levenshtein distance") - .exclude("soundex unit test") - .exclude("replace") .exclude("overlay for string") .exclude("overlay for byte array") .exclude("translate") - .exclude("FORMAT") - .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB") - .exclude("INSTR") .exclude("LOCATE") .exclude("LPAD/RPAD") .exclude("REPEAT") .exclude("length for string / binary") - .exclude("format_number / FormatNumber") .exclude("ParseUrl") .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url") - .exclude("Sentences") .excludeGlutenTest("SPARK-40213: ascii for Latin-1 Supplement characters") enableSuite[GlutenTryCastSuite] .exclude("null cast") diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala index 2b0b40790a76..e64f760ab55f 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala @@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS false ) } + + testGluten("flatten function") { + // Test cases with a primitive type + val intDF = Seq( + (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))), + (Seq(Seq(1, 2))), + (Seq(Seq(1), Seq.empty)), + (Seq(Seq.empty, Seq(1))) + ).toDF("i") + + val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1))) + + def testInt(): Unit = { + checkAnswer(intDF.select(flatten($"i")), intDFResult) + checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult) + } + + // Test with local relation, the Project will be evaluated without codegen + testInt() + // Test with cached relation, the Project will be evaluated with codegen + intDF.cache() + testInt() + + // Test cases with non-primitive types + val strDF = Seq( + (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))), + (Seq(Seq("a", "b"))), + (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))), + (Seq(Seq("a"), Seq.empty)), + (Seq(Seq.empty, Seq("a"))) + ).toDF("s") + + val strDFResult = Seq( + Row(Seq("a", "b", "c", "d", "e", "f")), + Row(Seq("a", "b")), + Row(Seq("a", null, null, "b", null, null)), + Row(Seq("a")), + Row(Seq("a"))) + + def testString(): Unit = { + checkAnswer(strDF.select(flatten($"s")), strDFResult) + checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult) + } + + // Test with local relation, the 
diff --git a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790a76..e64f760ab55f 100644
--- a/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++ b/gluten-ut/spark32/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
       false
     )
   }
+
+  testGluten("flatten function") {
+    // Test cases with a primitive type
+    val intDF = Seq(
+      (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))),
+      (Seq(Seq(1, 2))),
+      (Seq(Seq(1), Seq.empty)),
+      (Seq(Seq.empty, Seq(1)))
+    ).toDF("i")
+
+    val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1)))
+
+    def testInt(): Unit = {
+      checkAnswer(intDF.select(flatten($"i")), intDFResult)
+      checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testInt()
+    // Test with cached relation, the Project will be evaluated with codegen
+    intDF.cache()
+    testInt()
+
+    // Test cases with non-primitive types
+    val strDF = Seq(
+      (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))),
+      (Seq(Seq("a", "b"))),
+      (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))),
+      (Seq(Seq("a"), Seq.empty)),
+      (Seq(Seq.empty, Seq("a")))
+    ).toDF("s")
+
+    val strDFResult = Seq(
+      Row(Seq("a", "b", "c", "d", "e", "f")),
+      Row(Seq("a", "b")),
+      Row(Seq("a", null, null, "b", null, null)),
+      Row(Seq("a")),
+      Row(Seq("a")))
+
+    def testString(): Unit = {
+      checkAnswer(strDF.select(flatten($"s")), strDFResult)
+      checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testString()
+    // Test with cached relation, the Project will be evaluated with codegen
+    strDF.cache()
+    testString()
+
+    val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+
+    def testArray(): Unit = {
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
+        Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+        Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArray()
+    // Test with cached relation, the Project will be evaluated with codegen
+    arrDF.cache()
+    testArray()
+
+    // Error test cases
+    val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"arr"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"i"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"s"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.selectExpr("flatten(null)")
+    }
+  }
 }
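The testGluten("flatten function") case added above ports the vanilla DataFrameFunctionsSuite coverage for flatten to Gluten. A minimal, self-contained sketch of the behaviour it verifies (illustrative only, not part of the patch; assumes a plain local Spark session):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.flatten

object FlattenExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("flatten-example").master("local[1]").getOrCreate()
    import spark.implicits._

    // flatten removes exactly one level of nesting from an array of arrays.
    val df = Seq(Seq(Seq(1, 2), Seq(3)), Seq(Seq(4), Seq.empty[Int])).toDF("i")
    // Prints roughly: [1, 2, 3] and [4]
    df.select(flatten($"i")).show()
    // Applying flatten to a non-nested array (or a non-array) column raises an
    // AnalysisException, which is what the error cases in the test above assert.
    spark.stop()
  }
}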
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index da71110de3b4..52e7ebcbda49 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -190,6 +190,7 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("shuffle function - array for primitive type not containing null")
     .exclude("shuffle function - array for primitive type containing null")
     .exclude("shuffle function - array for non-primitive type")
+    .exclude("flatten function")
   enableSuite[GlutenDataFrameHintSuite]
   enableSuite[GlutenDataFrameImplicitsSuite]
   enableSuite[GlutenDataFrameJoinSuite].exclude(
@@ -457,7 +458,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("string regex_replace / regex_extract")
     .exclude("string overlay function")
     .exclude("binary overlay function")
-    .exclude("string / binary substring function")
     .exclude("string parse_url function")
   enableSuite[GlutenSubquerySuite]
     .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery")
@@ -473,58 +473,9 @@ class ClickHouseTestSettings extends BackendTestSettings {
   enableSuite[GlutenXPathFunctionsSuite]
   enableSuite[QueryTestSuite]
   enableSuite[GlutenAnsiCastSuiteWithAnsiModeOff]
-    .exclude("null cast")
     .exclude("cast string to date")
-    .exclude("cast string to timestamp")
-    .exclude("cast from boolean")
-    .exclude("cast from int")
-    .exclude("cast from long")
-    .exclude("cast from float")
-    .exclude("cast from double")
-    .exclude("cast from timestamp")
-    .exclude("data type casting")
-    .exclude("cast and add")
-    .exclude("from decimal")
-    .exclude("cast from array")
-    .exclude("cast from map")
-    .exclude("cast from struct")
-    .exclude("cast struct with a timestamp field")
-    .exclude("cast between string and interval")
-    .exclude("cast string to boolean")
-    .exclude("SPARK-20302 cast with same structure")
-    .exclude("SPARK-22500: cast for struct should not generate codes beyond 64KB")
-    .exclude("SPARK-27671: cast from nested null type in struct")
-    .exclude("Process Infinity, -Infinity, NaN in case insensitive manner")
-    .exclude("SPARK-22825 Cast array to string")
-    .exclude("SPARK-33291: Cast array with null elements to string")
-    .exclude("SPARK-22973 Cast map to string")
-    .exclude("SPARK-22981 Cast struct to string")
-    .exclude("SPARK-33291: Cast struct with null elements to string")
-    .exclude("SPARK-34667: cast year-month interval to string")
-    .exclude("SPARK-34668: cast day-time interval to string")
-    .exclude("SPARK-35698: cast timestamp without time zone to string")
     .exclude("SPARK-35711: cast timestamp without time zone to timestamp with local time zone")
-    .exclude("SPARK-35716: cast timestamp without time zone to date type")
-    .exclude("SPARK-35718: cast date type to timestamp without timezone")
-    .exclude("SPARK-35719: cast timestamp with local time zone to timestamp without timezone")
-    .exclude("SPARK-35720: cast string to timestamp without timezone")
-    .exclude("SPARK-35112: Cast string to day-time interval")
-    .exclude("SPARK-35111: Cast string to year-month interval")
-    .exclude("SPARK-35820: Support cast DayTimeIntervalType in different fields")
     .exclude("SPARK-35819: Support cast YearMonthIntervalType in different fields")
-    .exclude("SPARK-35768: Take into account year-month interval fields in cast")
-    .exclude("SPARK-35735: Take into account day-time interval fields in cast")
-    .exclude("ANSI mode: Throw exception on casting out-of-range value to byte type")
-    .exclude("ANSI mode: Throw exception on casting out-of-range value to short type")
-    .exclude("ANSI mode: Throw exception on casting out-of-range value to int type")
-    .exclude("ANSI mode: Throw exception on casting out-of-range value to long type")
-    .exclude("Fast fail for cast string type to decimal type in ansi mode")
-    .exclude("cast a timestamp before the epoch 1970-01-01 00:00:00Z")
-    .exclude("cast from array III")
-    .exclude("cast from map II")
-    .exclude("cast from map III")
-    .exclude("cast from struct II")
-    .exclude("cast from struct III")
   enableSuite[GlutenAnsiCastSuiteWithAnsiModeOn]
     .exclude("null cast")
     .exclude("cast string to date")
@@ -714,7 +665,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("Sequence with default step")
     .exclude("Reverse")
     .exclude("elementAt")
-    .exclude("Flatten")
     .exclude("ArrayRepeat")
     .exclude("Array remove")
     .exclude("Array Distinct")
@@ -902,7 +852,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-34814: LikeSimplification should handle NULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
-    .exclude("concat")
     .exclude("StringComparison")
     .exclude("Substring")
     .exclude("string substring_index function")
@@ -911,25 +860,14 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("string for ascii")
     .exclude("base64/unbase64 for string")
     .exclude("encode/decode for string")
-    .exclude("Levenshtein distance")
-    .exclude("soundex unit test")
-    .exclude("replace")
     .exclude("overlay for string")
     .exclude("overlay for byte array")
     .exclude("translate")
-    .exclude("FORMAT")
-    .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB")
-    .exclude("INSTR")
     .exclude("LOCATE")
-    .exclude("LPAD/RPAD")
     .exclude("REPEAT")
     .exclude("length for string / binary")
-    .exclude("format_number / FormatNumber")
-    .exclude("ToNumber: positive tests")
-    .exclude("ToNumber: negative tests (the input string does not match the format string)")
     .exclude("ParseUrl")
     .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url")
-    .exclude("Sentences")
   enableSuite[GlutenTryCastSuite]
     .exclude("null cast")
     .exclude("cast string to date")
diff --git a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
index 2b0b40790a76..e64f760ab55f 100644
--- a/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
+++ b/gluten-ut/spark33/src/test/scala/org/apache/spark/sql/GlutenDataFrameFunctionsSuite.scala
@@ -49,4 +49,86 @@ class GlutenDataFrameFunctionsSuite extends DataFrameFunctionsSuite with GlutenS
       false
     )
   }
+
+  testGluten("flatten function") {
+    // Test cases with a primitive type
+    val intDF = Seq(
+      (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))),
+      (Seq(Seq(1, 2))),
+      (Seq(Seq(1), Seq.empty)),
+      (Seq(Seq.empty, Seq(1)))
+    ).toDF("i")
+
+    val intDFResult = Seq(Row(Seq(1, 2, 3, 4, 5, 6)), Row(Seq(1, 2)), Row(Seq(1)), Row(Seq(1)))
+
+    def testInt(): Unit = {
+      checkAnswer(intDF.select(flatten($"i")), intDFResult)
+      checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testInt()
+    // Test with cached relation, the Project will be evaluated with codegen
+    intDF.cache()
+    testInt()
+
+    // Test cases with non-primitive types
+    val strDF = Seq(
+      (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))),
+      (Seq(Seq("a", "b"))),
+      (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))),
+      (Seq(Seq("a"), Seq.empty)),
+      (Seq(Seq.empty, Seq("a")))
+    ).toDF("s")
+
+    val strDFResult = Seq(
+      Row(Seq("a", "b", "c", "d", "e", "f")),
+      Row(Seq("a", "b")),
+      Row(Seq("a", null, null, "b", null, null)),
+      Row(Seq("a")),
+      Row(Seq("a")))
+
+    def testString(): Unit = {
+      checkAnswer(strDF.select(flatten($"s")), strDFResult)
+      checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testString()
+    // Test with cached relation, the Project will be evaluated with codegen
+    strDF.cache()
+    testString()
+
+    val arrDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+
+    def testArray(): Unit = {
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
+        Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))
+      checkAnswer(
+        arrDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
+        Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))
+    }
+
+    // Test with local relation, the Project will be evaluated without codegen
+    testArray()
+    // Test with cached relation, the Project will be evaluated with codegen
+    arrDF.cache()
+    testArray()
+
+    // Error test cases
+    val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"arr"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"i"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.select(flatten($"s"))
+    }
+    intercept[AnalysisException] {
+      oneRowDF.selectExpr("flatten(null)")
+    }
+  }
 }
diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 07af1fa845ca..38ed2c53463b 100644
--- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -457,7 +457,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("string regex_replace / regex_extract")
     .exclude("string overlay function")
     .exclude("binary overlay function")
-    .exclude("string / binary substring function")
     .exclude("string parse_url function")
   enableSuite[GlutenSubquerySuite]
     .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery")
@@ -756,7 +755,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-34814: LikeSimplification should handle NULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
-    .exclude("concat")
     .exclude("StringComparison")
     .exclude("Substring")
     .exclude("string substring_index function")
@@ -766,24 +764,14 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("base64/unbase64 for string")
     .exclude("encode/decode for string")
     .exclude("Levenshtein distance")
-    .exclude("soundex unit test")
-    .exclude("replace")
     .exclude("overlay for string")
     .exclude("overlay for byte array")
     .exclude("translate")
-    .exclude("FORMAT")
-    .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB")
-    .exclude("INSTR")
     .exclude("LOCATE")
-    .exclude("LPAD/RPAD")
     .exclude("REPEAT")
     .exclude("length for string / binary")
-    .exclude("format_number / FormatNumber")
-    .exclude("ToNumber: positive tests")
-    .exclude("ToNumber: negative tests (the input string does not match the format string)")
     .exclude("ParseUrl")
     .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url")
-    .exclude("Sentences")
   enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2SQLSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2SQLSuiteV1Filter]
diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
index 07af1fa845ca..38ed2c53463b 100644
--- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
+++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala
@@ -457,7 +457,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("string regex_replace / regex_extract")
     .exclude("string overlay function")
     .exclude("binary overlay function")
-    .exclude("string / binary substring function")
     .exclude("string parse_url function")
   enableSuite[GlutenSubquerySuite]
     .exclude("SPARK-15370: COUNT bug in subquery in subquery in subquery")
@@ -756,7 +755,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("SPARK-34814: LikeSimplification should handle NULL")
   enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
   enableSuite[GlutenStringExpressionsSuite]
-    .exclude("concat")
     .exclude("StringComparison")
     .exclude("Substring")
     .exclude("string substring_index function")
@@ -766,24 +764,14 @@ class ClickHouseTestSettings extends BackendTestSettings {
     .exclude("base64/unbase64 for string")
     .exclude("encode/decode for string")
     .exclude("Levenshtein distance")
-    .exclude("soundex unit test")
-    .exclude("replace")
     .exclude("overlay for string")
     .exclude("overlay for byte array")
     .exclude("translate")
-    .exclude("FORMAT")
-    .exclude("SPARK-22603: FormatString should not generate codes beyond 64KB")
-    .exclude("INSTR")
     .exclude("LOCATE")
-    .exclude("LPAD/RPAD")
     .exclude("REPEAT")
     .exclude("length for string / binary")
-    .exclude("format_number / FormatNumber")
-    .exclude("ToNumber: positive tests")
-    .exclude("ToNumber: negative tests (the input string does not match the format string)")
     .exclude("ParseUrl")
     .exclude("SPARK-33468: ParseUrl in ANSI mode should fail if input string is not a valid url")
-    .exclude("Sentences")
   enableSuite[GlutenDataSourceV2DataFrameSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2SQLSessionCatalogSuite]
   enableSuite[GlutenDataSourceV2SQLSuiteV1Filter]
diff --git a/pom.xml b/pom.xml
index 81ce0e5d462a..887839ce5fc0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -53,7 +53,7 @@
     delta-core
     2.4.0
     24
-    0.3.2-incubating
+    0.4.1
    0.8.0
    15.0.0
    15.0.0-gluten
diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
index 148e8cdc067c..58b99a7f3064 100644
--- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala
@@ -391,8 +391,7 @@ class GlutenConfig(conf: SQLConf) extends Logging {
     conf.getConf(COLUMNAR_VELOX_MEMORY_USE_HUGE_PAGES)

   def debug: Boolean = conf.getConf(DEBUG_ENABLED)
-  def debugKeepJniWorkspace: Boolean =
-    conf.getConf(DEBUG_ENABLED) && conf.getConf(DEBUG_KEEP_JNI_WORKSPACE)
+  def debugKeepJniWorkspace: Boolean = conf.getConf(DEBUG_KEEP_JNI_WORKSPACE)
   def taskStageId: Int = conf.getConf(BENCHMARK_TASK_STAGEID)
   def taskPartitionId: Int = conf.getConf(BENCHMARK_TASK_PARTITIONID)
   def taskId: Long = conf.getConf(BENCHMARK_TASK_TASK_ID)
@@ -436,6 +435,10 @@ class GlutenConfig(conf: SQLConf) extends Logging {

   def awsSdkLogLevel: String = conf.getConf(AWS_SDK_LOG_LEVEL)

+  def awsS3RetryMode: String = conf.getConf(AWS_S3_RETRY_MODE)
+
+  def awsConnectionTimeout: String = conf.getConf(AWS_S3_CONNECT_TIMEOUT)
+
   def enableCastAvgAggregateFunction: Boolean =
     conf.getConf(COLUMNAR_NATIVE_CAST_AGGREGATE_ENABLED)

   def enableGlutenCostEvaluator: Boolean = conf.getConf(COST_EVALUATOR_ENABLED)
@@ -444,6 +447,8 @@ class GlutenConfig(conf: SQLConf) extends Logging {
     conf.getConf(DYNAMIC_OFFHEAP_SIZING_ENABLED)

   def enableHiveFileFormatWriter: Boolean = conf.getConf(NATIVE_HIVEFILEFORMAT_WRITER_ENABLED)
+
+  def enableCelebornFallback: Boolean = conf.getConf(CELEBORN_FALLBACK_ENABLED)
 }

 object GlutenConfig {
@@ -488,6 +493,10 @@ object GlutenConfig {
   val SPARK_S3_IAM: String = HADOOP_PREFIX + S3_IAM_ROLE
   val S3_IAM_ROLE_SESSION_NAME = "fs.s3a.iam.role.session.name"
   val SPARK_S3_IAM_SESSION_NAME: String = HADOOP_PREFIX + S3_IAM_ROLE_SESSION_NAME
+  val S3_RETRY_MAX_ATTEMPTS = "fs.s3a.retry.limit"
+  val SPARK_S3_RETRY_MAX_ATTEMPTS: String = HADOOP_PREFIX + S3_RETRY_MAX_ATTEMPTS
+  val S3_CONNECTION_MAXIMUM = "fs.s3a.connection.maximum"
+  val SPARK_S3_CONNECTION_MAXIMUM: String = HADOOP_PREFIX + S3_CONNECTION_MAXIMUM

   // Hardware acceleraters backend
   val GLUTEN_SHUFFLE_CODEC_BACKEND = "spark.gluten.sql.columnar.shuffle.codecBackend"
@@ -545,6 +554,7 @@ object GlutenConfig {

   val GLUTEN_DEBUG_MODE = "spark.gluten.sql.debug"
   val GLUTEN_DEBUG_KEEP_JNI_WORKSPACE = "spark.gluten.sql.debug.keepJniWorkspace"
+  val GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR = "spark.gluten.sql.debug.keepJniWorkspaceDir"

   // Added back to Spark Conf during executor initialization
   val GLUTEN_NUM_TASK_SLOTS_PER_EXECUTOR_KEY = "spark.gluten.numTaskSlotsPerExecutor"
@@ -642,6 +652,10 @@ object GlutenConfig {
       SPARK_S3_USE_INSTANCE_CREDENTIALS,
       SPARK_S3_IAM,
       SPARK_S3_IAM_SESSION_NAME,
+      SPARK_S3_RETRY_MAX_ATTEMPTS,
+      SPARK_S3_CONNECTION_MAXIMUM,
+      AWS_S3_CONNECT_TIMEOUT.key,
+      AWS_S3_RETRY_MODE.key,
       AWS_SDK_LOG_LEVEL.key,
       // gcs config
       SPARK_GCS_STORAGE_ROOT_URL,
@@ -693,6 +707,10 @@
       (SPARK_S3_USE_INSTANCE_CREDENTIALS, "false"),
       (SPARK_S3_IAM, ""),
       (SPARK_S3_IAM_SESSION_NAME, ""),
+      (SPARK_S3_RETRY_MAX_ATTEMPTS, "20"),
+      (SPARK_S3_CONNECTION_MAXIMUM, "15"),
+      (AWS_S3_CONNECT_TIMEOUT.key, AWS_S3_CONNECT_TIMEOUT.defaultValueString),
+      (AWS_S3_RETRY_MODE.key, AWS_S3_RETRY_MODE.defaultValueString),
       (
         COLUMNAR_VELOX_CONNECTOR_IO_THREADS.key,
         conf.getOrElse(GLUTEN_NUM_TASK_SLOTS_PER_EXECUTOR_KEY, "-1")),
@@ -718,7 +736,9 @@
       GLUTEN_OFFHEAP_SIZE_IN_BYTES_KEY,
       GLUTEN_TASK_OFFHEAP_SIZE_IN_BYTES_KEY,
-      GLUTEN_OFFHEAP_ENABLED
+      GLUTEN_OFFHEAP_ENABLED,
+      SESSION_LOCAL_TIMEZONE.key,
+      DECIMAL_OPERATIONS_ALLOW_PREC_LOSS.key
     )
     nativeConfMap.putAll(conf.filter(e => keys.contains(e._1)).asJava)
@@ -735,10 +755,6 @@ object GlutenConfig {
       .filter(_._1.startsWith(SPARK_ABFS_ACCOUNT_KEY))
       .foreach(entry => nativeConfMap.put(entry._1, entry._2))

-    conf
-      .filter(_._1.startsWith(SQLConf.SESSION_LOCAL_TIMEZONE.key))
-      .foreach(entry => nativeConfMap.put(entry._1, entry._2))
-
     // return
     nativeConfMap
   }
@@ -1566,11 +1582,17 @@ object GlutenConfig {
       .createWithDefault(false)

   val DEBUG_KEEP_JNI_WORKSPACE =
-    buildConf(GLUTEN_DEBUG_KEEP_JNI_WORKSPACE)
+    buildStaticConf(GLUTEN_DEBUG_KEEP_JNI_WORKSPACE)
       .internal()
       .booleanConf
       .createWithDefault(false)

+  val DEBUG_KEEP_JNI_WORKSPACE_DIR =
+    buildStaticConf(GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR)
+      .internal()
+      .stringConf
+      .createWithDefault("/tmp")
+
   val BENCHMARK_TASK_STAGEID =
     buildConf("spark.gluten.sql.benchmark_task.stageId")
       .internal()
@@ -1943,6 +1965,20 @@ object GlutenConfig {
       .stringConf
       .createWithDefault("FATAL")

+  val AWS_S3_RETRY_MODE =
+    buildConf("spark.gluten.velox.fs.s3a.retry.mode")
+      .internal()
+      .doc("Retry mode for AWS s3 connection error: legacy, standard and adaptive.")
+      .stringConf
+      .createWithDefault("legacy")
+
+  val AWS_S3_CONNECT_TIMEOUT =
+    buildConf("spark.gluten.velox.fs.s3a.connect.timeout")
+      .internal()
+      .doc("Timeout for AWS s3 connection.")
+      .stringConf
+      .createWithDefault("200s")
+
   val VELOX_ORC_SCAN_ENABLED =
     buildStaticConf("spark.gluten.sql.columnar.backend.velox.orc.scan.enabled")
       .internal()
@@ -2015,4 +2051,12 @@ object GlutenConfig {
       .doubleConf
       .checkValue(v => v >= 0 && v <= 1, "offheap sizing memory fraction must between [0, 1]")
       .createWithDefault(0.6)
+
+  val CELEBORN_FALLBACK_ENABLED =
+    buildStaticConf("spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled")
+      .internal()
+      .doc("If enabled, fall back to ColumnarShuffleManager when celeborn service is unavailable." +
+        "Otherwise, throw an exception.")
+      .booleanConf
+      .createWithDefault(true)
 }
diff --git a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
index 112fa677d2cd..8317e28b58bb 100644
--- a/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
+++ b/shims/common/src/main/scala/org/apache/gluten/expression/ExpressionNames.scala
@@ -83,10 +83,10 @@ object ExpressionNames {
   final val IS_NAN = "isnan"
   final val NANVL = "nanvl"
   final val TRY_EVAL = "try"
-  final val CHECK_ADD = "check_add"
-  final val CHECK_SUBTRACT = "check_subtract"
-  final val CHECK_DIVIDE = "check_divide"
-  final val CHECK_MULTIPLY = "check_multiply"
+  final val CHECKED_ADD = "checked_add"
+  final val CHECKED_SUBTRACT = "checked_subtract"
+  final val CHECKED_DIVIDE = "checked_divide"
+  final val CHECKED_MULTIPLY = "checked_multiply"

   // SparkSQL String functions
   final val ASCII = "ascii"
@@ -127,6 +127,7 @@ object ExpressionNames {
   final val ENCODE = "encode"
   final val UUID = "uuid"
   final val BIT_LENGTH = "bit_length"
+  final val LEVENSHTEIN = "levenshteinDistance"
   final val UNBASE64 = "unbase64"
   final val BASE64 = "base64"
diff --git a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
index 95571f166ebe..f6feae01a8b2 100644
--- a/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
+++ b/shims/spark35/src/main/scala/org/apache/gluten/sql/shims/spark35/Spark35Shims.scala
@@ -76,7 +76,9 @@ class Spark35Shims extends SparkShims {
       Sig[SplitPart](ExpressionNames.SPLIT_PART),
       Sig[Sec](ExpressionNames.SEC),
       Sig[Csc](ExpressionNames.CSC),
-      Sig[Empty2Null](ExpressionNames.EMPTY2NULL))
+      Sig[KnownNullable](ExpressionNames.KNOWN_NULLABLE),
+      Sig[Empty2Null](ExpressionNames.EMPTY2NULL)
+    )
   }

   override def aggregateExpressionMappings: Seq[Sig] = {
diff --git a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala
index 50766f3a91d1..e680ce9d5dda 100644
--- a/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala
+++ b/tools/gluten-it/common/src/main/scala/org/apache/gluten/integration/Constants.scala
@@ -44,6 +44,7 @@ object Constants {
   val VELOX_WITH_CELEBORN_CONF: SparkConf = new SparkConf(false)
     .set("spark.gluten.sql.columnar.forceShuffledHashJoin", "true")
+    .set("spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled", "false")
     .set("spark.sql.parquet.enableVectorizedReader", "true")
     .set("spark.plugins", "org.apache.gluten.GlutenPlugin")
     .set(
diff --git a/tools/gluten-it/pom.xml b/tools/gluten-it/pom.xml
index 3f1760069792..71db637a8403 100644
--- a/tools/gluten-it/pom.xml
+++ b/tools/gluten-it/pom.xml
@@ -21,7 +21,7 @@
     3.4.2
     2.12
     3
-    0.3.0-incubating
+    0.3.2-incubating
     0.8.0
     1.2.0-SNAPSHOT
     32.0.1-jre
@@ -167,7 +167,7 @@
       celeborn-0.4
-      0.4.0-incubating
+      0.4.1
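For reference, a hedged sketch of how a job might set the options introduced in GlutenConfig above: spark.gluten.velox.fs.s3a.retry.mode, spark.gluten.velox.fs.s3a.connect.timeout, the forwarded fs.s3a.retry.limit and fs.s3a.connection.maximum settings, and spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled. This is not part of the patch, and the spark.hadoop. prefix used for the fs.s3a.* keys is an assumption about GlutenConfig's HADOOP_PREFIX:

import org.apache.spark.SparkConf

object GlutenS3AndCelebornConfExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      // AWS SDK retry behaviour for the Velox S3 client: legacy, standard or adaptive.
      .set("spark.gluten.velox.fs.s3a.retry.mode", "adaptive")
      // Connect timeout for the Velox S3 client (the new ConfigEntry defaults to "200s").
      .set("spark.gluten.velox.fs.s3a.connect.timeout", "60s")
      // Hadoop S3A settings that the patch now forwards to the native backend.
      .set("spark.hadoop.fs.s3a.retry.limit", "20")
      .set("spark.hadoop.fs.s3a.connection.maximum", "15")
      // Fail fast instead of silently falling back to ColumnarShuffleManager when
      // the Celeborn service is unavailable (mirrors the Constants.scala change above).
      .set("spark.gluten.sql.columnar.shuffle.celeborn.fallback.enabled", "false")
    // Print the effective settings; in a real job this conf would seed the SparkSession.
    conf.getAll.foreach { case (k, v) => println(s"$k=$v") }
  }
}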