From 1fbdbc41779321db3380bce0807b73389af64e1a Mon Sep 17 00:00:00 2001 From: Chang chen Date: Tue, 25 Jun 2024 07:13:39 +0800 Subject: [PATCH] [GLUTEN-6067][CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting sink transform (#6197) [CH] [Part 2] Support CH backend with Spark3.5 - Prepare for supporting sink transform * [Refactor] remove duplicate codes * Add NativeWriteChecker * [Prepare to commit] getExtendedColumnarPostRules from Spark shim --- .../clickhouse/CHIteratorApi.scala | 143 ++-- .../clickhouse/CHSparkPlanExecApi.scala | 9 - .../execution/CHHashJoinExecTransformer.scala | 3 +- ...lutenClickHouseNativeWriteTableSuite.scala | 612 ++++++++---------- .../GlutenClickHouseTPCHMetricsSuite.scala | 2 +- .../spark/gluten/NativeWriteChecker.scala | 52 ++ .../velox/VeloxSparkPlanExecApi.scala | 9 - cpp-ch/local-engine/Common/CHUtil.cpp | 17 +- cpp-ch/local-engine/Common/CHUtil.h | 12 +- .../Parser/CHColumnToSparkRow.cpp | 2 +- .../Parser/SerializedPlanParser.cpp | 310 ++++----- .../Parser/SerializedPlanParser.h | 39 +- cpp-ch/local-engine/local_engine_jni.cpp | 39 +- .../tests/benchmark_local_engine.cpp | 80 +-- cpp-ch/local-engine/tests/gluten_test_util.h | 18 + .../local-engine/tests/gtest_local_engine.cpp | 22 +- cpp-ch/local-engine/tests/gtest_parser.cpp | 407 ++++-------- .../tests/json/clickhouse_pr_65234.json | 273 ++++++++ .../tests/json/gtest_local_engine_config.json | 269 ++++++++ .../json/read_student_option_schema.csv.json | 77 +++ .../gluten/backendsapi/SparkPlanExecApi.scala | 4 +- .../utils/SubstraitPlanPrinterUtil.scala | 35 +- 22 files changed, 1379 insertions(+), 1055 deletions(-) create mode 100644 backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala create mode 100644 cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json create mode 100644 cpp-ch/local-engine/tests/json/gtest_local_engine_config.json create mode 100644 cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index 941237629569..376e46ebe975 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.backendsapi.clickhouse -import org.apache.gluten.{GlutenConfig, GlutenNumaBindingInfo} +import org.apache.gluten.GlutenNumaBindingInfo import org.apache.gluten.backendsapi.IteratorApi import org.apache.gluten.execution._ import org.apache.gluten.expression.ConverterUtils @@ -61,6 +61,52 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { StructType(dataSchema) } + private def createNativeIterator( + splitInfoByteArray: Array[Array[Byte]], + wsPlan: Array[Byte], + materializeInput: Boolean, + inputIterators: Seq[Iterator[ColumnarBatch]]): BatchIterator = { + + /** Generate closeable ColumnBatch iterator. 
*/ + val listIterator = + inputIterators + .map { + case i: CloseableCHColumnBatchIterator => i + case it => new CloseableCHColumnBatchIterator(it) + } + .map(it => new ColumnarNativeIterator(it.asJava).asInstanceOf[GeneralInIterator]) + .asJava + new CHNativeExpressionEvaluator().createKernelWithBatchIterator( + wsPlan, + splitInfoByteArray, + listIterator, + materializeInput + ) + } + + private def createCloseIterator( + context: TaskContext, + pipelineTime: SQLMetric, + updateNativeMetrics: IMetrics => Unit, + updateInputMetrics: Option[InputMetricsWrapper => Unit] = None, + nativeIter: BatchIterator): CloseableCHColumnBatchIterator = { + + val iter = new CollectMetricIterator( + nativeIter, + updateNativeMetrics, + updateInputMetrics, + updateInputMetrics.map(_ => context.taskMetrics().inputMetrics).orNull) + + context.addTaskFailureListener( + (ctx, _) => { + if (ctx.isInterrupted()) { + iter.cancel() + } + }) + context.addTaskCompletionListener[Unit](_ => iter.close()) + new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) + } + // only set file schema for text format table private def setFileSchemaForLocalFiles( localFilesNode: LocalFilesNode, @@ -198,45 +244,24 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { inputIterators: Seq[Iterator[ColumnarBatch]] = Seq() ): Iterator[ColumnarBatch] = { - assert( + require( inputPartition.isInstanceOf[GlutenPartition], "CH backend only accepts GlutenPartition in GlutenWholeStageColumnarRDD.") - - val transKernel = new CHNativeExpressionEvaluator() - val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map { - iter => new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava) - }.asJava) - val splitInfoByteArray = inputPartition .asInstanceOf[GlutenPartition] .splitInfosByteArray - val nativeIter = - transKernel.createKernelWithBatchIterator( - inputPartition.plan, - splitInfoByteArray, - inBatchIters, - false) + val wsPlan = inputPartition.plan + val materializeInput = false - val iter = new CollectMetricIterator( - nativeIter, - updateNativeMetrics, - updateInputMetrics, - context.taskMetrics().inputMetrics) - - context.addTaskFailureListener( - (ctx, _) => { - if (ctx.isInterrupted()) { - iter.cancel() - } - }) - context.addTaskCompletionListener[Unit](_ => iter.close()) - - // TODO: SPARK-25083 remove the type erasure hack in data source scan new InterruptibleIterator( context, - new CloseableCHColumnBatchIterator( - iter.asInstanceOf[Iterator[ColumnarBatch]], - Some(pipelineTime))) + createCloseIterator( + context, + pipelineTime, + updateNativeMetrics, + Some(updateInputMetrics), + createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) + ) } // Generate Iterator[ColumnarBatch] for final stage. @@ -252,52 +277,26 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { partitionIndex: Int, materializeInput: Boolean): Iterator[ColumnarBatch] = { // scalastyle:on argcount - GlutenConfig.getConf - - val transKernel = new CHNativeExpressionEvaluator() - val columnarNativeIterator = - new JArrayList[GeneralInIterator](inputIterators.map { - iter => - new ColumnarNativeIterator(CHIteratorApi.genCloseableColumnBatchIterator(iter).asJava) - }.asJava) - // we need to complete dependency RDD's firstly - val nativeIterator = transKernel.createKernelWithBatchIterator( - rootNode.toProtobuf.toByteArray, - // Final iterator does not contain scan split, so pass empty split info to native here. 
- new Array[Array[Byte]](0), - columnarNativeIterator, - materializeInput - ) - - val iter = new CollectMetricIterator(nativeIterator, updateNativeMetrics, null, null) - context.addTaskFailureListener( - (ctx, _) => { - if (ctx.isInterrupted()) { - iter.cancel() - } - }) - context.addTaskCompletionListener[Unit](_ => iter.close()) - new CloseableCHColumnBatchIterator(iter, Some(pipelineTime)) - } -} + // Final iterator does not contain scan split, so pass empty split info to native here. + val splitInfoByteArray = new Array[Array[Byte]](0) + val wsPlan = rootNode.toProtobuf.toByteArray -object CHIteratorApi { - - /** Generate closeable ColumnBatch iterator. */ - def genCloseableColumnBatchIterator(iter: Iterator[ColumnarBatch]): Iterator[ColumnarBatch] = { - iter match { - case _: CloseableCHColumnBatchIterator => iter - case _ => new CloseableCHColumnBatchIterator(iter) - } + // we need to complete dependency RDD's firstly + createCloseIterator( + context, + pipelineTime, + updateNativeMetrics, + None, + createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) } } class CollectMetricIterator( val nativeIterator: BatchIterator, val updateNativeMetrics: IMetrics => Unit, - val updateInputMetrics: InputMetricsWrapper => Unit, - val inputMetrics: InputMetrics + val updateInputMetrics: Option[InputMetricsWrapper => Unit] = None, + val inputMetrics: InputMetrics = null ) extends Iterator[ColumnarBatch] { private var outputRowCount = 0L private var outputVectorCount = 0L @@ -329,9 +328,7 @@ class CollectMetricIterator( val nativeMetrics = nativeIterator.getMetrics.asInstanceOf[NativeMetrics] nativeMetrics.setFinalOutputMetrics(outputRowCount, outputVectorCount) updateNativeMetrics(nativeMetrics) - if (updateInputMetrics != null) { - updateInputMetrics(inputMetrics) - } + updateInputMetrics.foreach(_(inputMetrics)) metricsUpdated = true } } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 1c83e326eed4..ac3ea61ff810 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -50,7 +50,6 @@ import org.apache.spark.sql.delta.files.TahoeFileIndex import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec import org.apache.spark.sql.execution.datasources.{FileFormat, HadoopFsRelation} -import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules.NativeWritePostRule import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.clickhouse.source.DeltaMergeTreeFileFormat import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} @@ -583,14 +582,6 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { override def genExtendedColumnarTransformRules(): List[SparkSession => Rule[SparkPlan]] = List() - /** - * Generate extended columnar post-rules. 
- * - * @return - */ - override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = - List(spark => NativeWritePostRule(spark)) - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List() } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala index a7e7769e7736..da9d9c7586c0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/execution/CHHashJoinExecTransformer.scala @@ -16,7 +16,6 @@ */ package org.apache.gluten.execution -import org.apache.gluten.backendsapi.clickhouse.CHIteratorApi import org.apache.gluten.extension.ValidationResult import org.apache.gluten.utils.{BroadcastHashJoinStrategy, CHJoinValidateUtil, ShuffleHashJoinStrategy} @@ -75,7 +74,7 @@ case class CHBroadcastBuildSideRDD( override def genBroadcastBuildSideIterator(): Iterator[ColumnarBatch] = { CHBroadcastBuildSideCache.getOrBuildBroadcastHashTable(broadcasted, broadcastContext) - CHIteratorApi.genCloseableColumnBatchIterator(Iterator.empty) + Iterator.empty } } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala index 9269303d9251..ccf7bb5d5b2a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala @@ -21,6 +21,7 @@ import org.apache.gluten.execution.AllDataTypesWithComplexType.genTestData import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf +import org.apache.spark.gluten.NativeWriteChecker import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -28,11 +29,14 @@ import org.apache.spark.sql.test.SharedSparkSession import org.scalatest.BeforeAndAfterAll +import scala.reflect.runtime.universe.TypeTag + class GlutenClickHouseNativeWriteTableSuite extends GlutenClickHouseWholeStageTransformerSuite with AdaptiveSparkPlanHelper with SharedSparkSession - with BeforeAndAfterAll { + with BeforeAndAfterAll + with NativeWriteChecker { private var _hiveSpark: SparkSession = _ @@ -114,16 +118,19 @@ class GlutenClickHouseNativeWriteTableSuite def getColumnName(s: String): String = { s.replaceAll("\\(", "_").replaceAll("\\)", "_") } + import collection.immutable.ListMap import java.io.File def writeIntoNewTableWithSql(table_name: String, table_create_sql: String)( fields: Seq[String]): Unit = { - spark.sql(table_create_sql) - spark.sql( - s"insert overwrite $table_name select ${fields.mkString(",")}" + - s" from origin_table") + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite( + s"insert overwrite $table_name select ${fields.mkString(",")}" + + s" from origin_table", + checkNative = true) + } } def writeAndCheckRead( @@ -170,82 +177,86 @@ class GlutenClickHouseNativeWriteTableSuite }) } - test("test insert into dir") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - 
originDF.createOrReplaceTempView("origin_table") + private val fields_ = ListMap( + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("byte_field", "byte"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)"), + ("date_field", "date") + ) - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) + def withDestinationTable(table: String, createTableSql: String)(f: => Unit): Unit = { + spark.sql(s"drop table IF EXISTS $table") + spark.sql(s"$createTableSql") + f + } - for (format <- formats) { - spark.sql( - s"insert overwrite local directory '$basePath/test_insert_into_${format}_dir1' " - + s"stored as $format select " - + fields.keys.mkString(",") + - " from origin_table cluster by (byte_field)") - spark.sql( - s"insert overwrite local directory '$basePath/test_insert_into_${format}_dir2' " + - s"stored as $format " + - "select string_field, sum(int_field) as x from origin_table group by string_field") - } + def nativeWrite(f: String => Unit): Unit = { + withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { + formats.foreach(f(_)) } } - test("test insert into partition") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - ("spark.sql.orc.compression.codec", "lz4"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") + def nativeWrite2( + f: String => (String, String, String), + extraCheck: (String, String, String) => Unit = null): Unit = nativeWrite { + format => + val (table_name, table_create_sql, insert_sql) = f(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) + Option(extraCheck).foreach(_(table_name, table_create_sql, insert_sql)) + } + } - val table_create_sql = - s"create table if not exists $table_name (" + - fields - .map(f => s"${f._1} ${f._2}") - .mkString(",") + - " ) partitioned by (another_date_field date) " + - s"stored as $format" + def nativeWriteWithOriginalView[A <: Product: TypeTag]( + data: Seq[A], + viewName: String, + pairs: (String, String)*)(f: String => Unit): Unit = { + val configs = pairs :+ ("spark.gluten.sql.native.writer.enabled", "true") + withSQLConf(configs: _*) { + withTempView(viewName) { + spark.createDataFrame(data).createOrReplaceTempView(viewName) + formats.foreach(f(_)) + } + } + } - spark.sql(table_create_sql) + test("test insert into dir") { + nativeWriteWithOriginalView(genTestData(), "origin_table") { + format => + Seq( + s"""insert overwrite local directory '$basePath/test_insert_into_${format}_dir1' + |stored as $format select ${fields_.keys.mkString(",")} + |from 
origin_table""".stripMargin, + s"""insert overwrite local directory '$basePath/test_insert_into_${format}_dir2' + |stored as $format select string_field, sum(int_field) as x + |from origin_table group by string_field""".stripMargin + ).foreach(checkNativeWrite(_, checkNative = true)) + } + } - spark.sql( - s"insert into $table_name partition(another_date_field = '2020-01-01') select " - + fields.keys.mkString(",") + - " from origin_table") + test("test insert into partition") { + def destination(format: String): (String, String, String) = { + val table_name = table_name_template.format(format) + val table_create_sql = + s"""create table if not exists $table_name + |(${fields_.map(f => s"${f._1} ${f._2}").mkString(",")}) + |partitioned by (another_date_field date) stored as $format""".stripMargin + val insert_sql = + s"""insert into $table_name partition(another_date_field = '2020-01-01') + | select ${fields_.keys.mkString(",")} from origin_table""".stripMargin + (table_name, table_create_sql, insert_sql) + } + def nativeFormatWrite(format: String): Unit = { + val (table_name, table_create_sql, insert_sql) = destination(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) var files = recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) .filter(_.getName.endsWith(s".$format")) if (format == "orc") { @@ -255,154 +266,103 @@ class GlutenClickHouseNativeWriteTableSuite assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01")) } } + + nativeWriteWithOriginalView( + genTestData(), + "origin_table", + ("spark.sql.orc.compression.codec", "lz4"))(nativeFormatWrite) } test("test CTAS") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { + nativeWriteWithOriginalView(genTestData(), "origin_table") { + format => val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") val table_create_sql = s"create table $table_name using $format as select " + - fields + fields_ .map(f => s"${f._1}") .mkString(",") + " from origin_table" - spark.sql(table_create_sql) - spark.sql(s"drop table IF EXISTS $table_name") + val insert_sql = + s"create table $table_name as select " + + fields_ + .map(f => s"${f._1}") + .mkString(",") + + " from origin_table" + withDestinationTable(table_name, table_create_sql) { + spark.sql(s"drop table IF EXISTS $table_name") - try { - val table_create_sql = - s"create table $table_name as select " + - fields - .map(f => s"${f._1}") - .mkString(",") + - " from origin_table" - spark.sql(table_create_sql) - } catch { - case _: UnsupportedOperationException => // expected - case _: Exception => fail("should not throw exception") + try { + // FIXME: using checkNativeWrite + spark.sql(insert_sql) + } catch { + case _: UnsupportedOperationException => // expected + case e: Exception => fail("should not throw exception", e) + } } - } } } test("test insert into partition, bigo's case which incur InsertIntoHiveTable") { - 
withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - ("spark.sql.hive.convertMetastoreParquet", "false"), - ("spark.sql.hive.convertMetastoreOrc", "false"), - (GlutenConfig.GLUTEN_ENABLED.key, "true") - ) { - - val originDF = spark.createDataFrame(genTestData()) - originDF.createOrReplaceTempView("origin_table") - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = table_name_template.format(format) - spark.sql(s"drop table IF EXISTS $table_name") - val table_create_sql = s"create table if not exists $table_name (" + fields - .map(f => s"${f._1} ${f._2}") - .mkString(",") + " ) partitioned by (another_date_field string)" + - s"stored as $format" + def destination(format: String): (String, String, String) = { + val table_name = table_name_template.format(format) + val table_create_sql = s"create table if not exists $table_name (" + fields_ + .map(f => s"${f._1} ${f._2}") + .mkString(",") + " ) partitioned by (another_date_field string)" + + s"stored as $format" + val insert_sql = + s"insert overwrite table $table_name " + + "partition(another_date_field = '2020-01-01') select " + + fields_.keys.mkString(",") + " from (select " + fields_.keys.mkString( + ",") + ", row_number() over (order by int_field desc) as rn " + + "from origin_table where float_field > 3 ) tt where rn <= 100" + (table_name, table_create_sql, insert_sql) + } - spark.sql(table_create_sql) - spark.sql( - s"insert overwrite table $table_name " + - "partition(another_date_field = '2020-01-01') select " - + fields.keys.mkString(",") + " from (select " + fields.keys.mkString( - ",") + ", row_number() over (order by int_field desc) as rn " + - "from origin_table where float_field > 3 ) tt where rn <= 100") + def nativeFormatWrite(format: String): Unit = { + val (table_name, table_create_sql, insert_sql) = destination(format) + withDestinationTable(table_name, table_create_sql) { + checkNativeWrite(insert_sql, checkNative = true) val files = recursiveListFiles(new File(getWarehouseDir + "/" + table_name)) .filter(_.getName.startsWith("part")) assert(files.length == 1) assert(files.head.getAbsolutePath.contains("another_date_field=2020-01-01")) } } + + nativeWriteWithOriginalView( + genTestData(), + "origin_table", + ("spark.sql.hive.convertMetastoreParquet", "false"), + ("spark.sql.hive.convertMetastoreOrc", "false"))(nativeFormatWrite) } test("test 1-col partitioned table") { + nativeWrite { + format => + { + val table_name = table_name_template.format(format) + val table_create_sql = + s"create table if not exists $table_name (" + + fields_ + .filterNot(e => e._1.equals("date_field")) + .map(f => s"${f._1} ${f._2}") + .mkString(",") + + " ) partitioned by (date_field date) " + + s"stored as $format" - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { - val table_name = 
table_name_template.format(format) - val table_create_sql = - s"create table if not exists $table_name (" + - fields - .filterNot(e => e._1.equals("date_field")) - .map(f => s"${f._1} ${f._2}") - .mkString(",") + - " ) partitioned by (date_field date) " + - s"stored as $format" - - writeAndCheckRead( - table_name, - writeIntoNewTableWithSql(table_name, table_create_sql), - fields.keys.toSeq) - } + writeAndCheckRead( + table_name, + writeIntoNewTableWithSql(table_name, table_create_sql), + fields_.keys.toSeq) + } } } // even if disable native writer, this UT fail, spark bug??? ignore("test 1-col partitioned table, partitioned by already ordered column") { withSQLConf(("spark.gluten.sql.native.writer.enabled", "false")) { - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) val originDF = spark.createDataFrame(genTestData()) originDF.createOrReplaceTempView("origin_table") @@ -410,7 +370,7 @@ class GlutenClickHouseNativeWriteTableSuite val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + - fields + fields_ .filterNot(e => e._1.equals("date_field")) .map(f => s"${f._1} ${f._2}") .mkString(",") + @@ -420,31 +380,27 @@ class GlutenClickHouseNativeWriteTableSuite spark.sql(s"drop table IF EXISTS $table_name") spark.sql(table_create_sql) spark.sql( - s"insert overwrite $table_name select ${fields.mkString(",")}" + + s"insert overwrite $table_name select ${fields_.mkString(",")}" + s" from origin_table order by date_field") } } } test("test 2-col partitioned table") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date"), - ("byte_field", "byte") - ) - - for (format <- formats) { + val fields: ListMap[String, String] = ListMap( + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)"), + ("date_field", "date"), + ("byte_field", "byte") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -458,7 +414,6 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } @@ -506,25 +461,21 @@ class GlutenClickHouseNativeWriteTableSuite // This test case will be failed with incorrect result randomly, ignore first. ignore("test hive parquet/orc table, all columns being partitioned. 
") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("date_field", "date"), - ("timestamp_field", "timestamp"), - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)") - ) - - for (format <- formats) { + val fields: ListMap[String, String] = ListMap( + ("date_field", "date"), + ("timestamp_field", "timestamp"), + ("string_field", "string"), + ("int_field", "int"), + ("long_field", "long"), + ("float_field", "float"), + ("double_field", "double"), + ("short_field", "short"), + ("byte_field", "byte"), + ("boolean_field", "boolean"), + ("decimal_field", "decimal(23,12)") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -540,20 +491,15 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } - test(("test hive parquet/orc table with aggregated results")) { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("sum(int_field)", "bigint") - ) - - for (format <- formats) { + test("test hive parquet/orc table with aggregated results") { + val fields: ListMap[String, String] = ListMap( + ("sum(int_field)", "bigint") + ) + nativeWrite { + format => val table_name = table_name_template.format(format) val table_create_sql = s"create table if not exists $table_name (" + @@ -566,29 +512,12 @@ class GlutenClickHouseNativeWriteTableSuite table_name, writeIntoNewTableWithSql(table_name, table_create_sql), fields.keys.toSeq) - } } } test("test 1-col partitioned + 1-col bucketed table") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - - val fields: ListMap[String, String] = ListMap( - ("string_field", "string"), - ("int_field", "int"), - ("long_field", "long"), - ("float_field", "float"), - ("double_field", "double"), - ("short_field", "short"), - ("byte_field", "byte"), - ("boolean_field", "boolean"), - ("decimal_field", "decimal(23,12)"), - ("date_field", "date") - ) - - for (format <- formats) { + nativeWrite { + format => // spark write does not support bucketed table // https://issues.apache.org/jira/browse/SPARK-19256 val table_name = table_name_template.format(format) @@ -604,7 +533,7 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "byte_field") .saveAsTable(table_name) }, - fields.keys.toSeq + fields_.keys.toSeq ) assert( @@ -614,10 +543,8 @@ class GlutenClickHouseNativeWriteTableSuite .filter(!_.getName.equals("date_field=__HIVE_DEFAULT_PARTITION__")) .head .listFiles() - .filter(!_.isHidden) - .length == 2 + .count(!_.isHidden) == 2 ) // 2 bucket files - } } } @@ -745,8 +672,8 @@ class GlutenClickHouseNativeWriteTableSuite } test("test consecutive blocks having same partition value") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -760,15 +687,14 @@ class GlutenClickHouseNativeWriteTableSuite .partitionBy("p") 
.saveAsTable(table_name) - val ret = spark.sql("select sum(id) from " + table_name).collect().apply(0).apply(0) + val ret = spark.sql(s"select sum(id) from $table_name").collect().apply(0).apply(0) assert(ret == 449985000) - } } } test("test decimal with rand()") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") spark @@ -778,32 +704,30 @@ class GlutenClickHouseNativeWriteTableSuite .format(format) .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select max(p) from " + table_name).collect().apply(0).apply(0) - } + val ret = spark.sql(s"select max(p) from $table_name").collect().apply(0).apply(0) } } test("test partitioned by constant") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { - spark.sql(s"drop table IF EXISTS tmp_123_$format") - spark.sql( - s"create table tmp_123_$format(" + - s"x1 string, x2 bigint,x3 string, x4 bigint, x5 string )" + - s"partitioned by (day date) stored as $format") - - spark.sql( - s"insert into tmp_123_$format partition(day) " + - "select cast(id as string), id, cast(id as string), id, cast(id as string), " + - "'2023-05-09' from range(10000000)") - } + nativeWrite2 { + format => + val table_name = s"tmp_123_$format" + val create_sql = + s"""create table tmp_123_$format( + |x1 string, x2 bigint,x3 string, x4 bigint, x5 string ) + |partitioned by (day date) stored as $format""".stripMargin + val insert_sql = + s"""insert into tmp_123_$format partition(day) + |select cast(id as string), id, cast(id as string), + | id, cast(id as string), '2023-05-09' + |from range(10000000)""".stripMargin + (table_name, create_sql, insert_sql) } } test("test bucketed by constant") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -815,15 +739,13 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(10000000)(spark.table(table_name).count()) } } test("test consecutive null values being partitioned") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -835,14 +757,13 @@ class GlutenClickHouseNativeWriteTableSuite .partitionBy("p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(30000)(spark.table(table_name).count()) } } test("test consecutive null values being bucketed") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") @@ -854,78 +775,79 @@ class GlutenClickHouseNativeWriteTableSuite .bucketBy(2, "p") .saveAsTable(table_name) - val ret = spark.sql("select count(*) from " + table_name).collect().apply(0).apply(0) - } + assertResult(30000)(spark.table(table_name).count()) } } test("test native write with empty dataset") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for 
(format <- formats) { + nativeWrite2( + format => { val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, str string) stored as $format") - spark.sql( - s"insert into $table_name select id, cast(id as string) from range(10)" + - " where id > 100") + ( + table_name, + s"create table $table_name (id int, str string) stored as $format", + s"insert into $table_name select id, cast(id as string) from range(10) where id > 100" + ) + }, + (table_name, _, _) => { + assertResult(0)(spark.table(table_name).count()) } - } + ) } test("test native write with union") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { + nativeWrite { + format => val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, str string) stored as $format") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string) from range(10) union all " + - "select 10, '10' from range(10)") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string) from range(10) union all " + - "select 10, cast(id as string) from range(10)") - - } + withDestinationTable( + table_name, + s"create table $table_name (id int, str string) stored as $format") { + checkNativeWrite( + s"insert overwrite table $table_name " + + "select id, cast(id as string) from range(10) union all " + + "select 10, '10' from range(10)", + checkNative = true) + checkNativeWrite( + s"insert overwrite table $table_name " + + "select id, cast(id as string) from range(10) union all " + + "select 10, cast(id as string) from range(10)", + checkNative = true + ) + } } } test("test native write and non-native read consistency") { - withSQLConf(("spark.gluten.sql.native.writer.enabled", "true")) { - for (format <- formats) { - val table_name = "t_" + format - spark.sql(s"drop table IF EXISTS $table_name") - spark.sql(s"create table $table_name (id int, name string, info char(4)) stored as $format") - spark.sql( - s"insert overwrite table $table_name " + - "select id, cast(id as string), concat('aaa', cast(id as string)) from range(10)") + nativeWrite2( + { + format => + val table_name = "t_" + format + ( + table_name, + s"create table $table_name (id int, name string, info char(4)) stored as $format", + s"insert overwrite table $table_name " + + "select id, cast(id as string), concat('aaa', cast(id as string)) from range(10)" + ) + }, + (table_name, _, _) => compareResultsAgainstVanillaSpark( s"select * from $table_name", compareResult = true, _ => {}) - } - } + ) } test("GLUTEN-4316: fix crash on dynamic partition inserting") { - withSQLConf( - ("spark.gluten.sql.native.writer.enabled", "true"), - (GlutenConfig.GLUTEN_ENABLED.key, "true")) { - formats.foreach( - format => { - val tbl = "t_" + format - spark.sql(s"drop table IF EXISTS $tbl") - val sql1 = - s"create table $tbl(a int, b map, c struct) " + - s"partitioned by (day string) stored as $format" - val sql2 = s"insert overwrite $tbl partition (day) " + - s"select id as a, str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + - s"struct('1', null) as c, '2024-01-08' as day from range(10)" - spark.sql(sql1) - spark.sql(sql2) - }) + nativeWrite2 { + format => + val tbl = "t_" + format + val sql1 = + s"create table $tbl(a int, b map, c struct) " + + s"partitioned by (day string) stored as $format" + val sql2 = s"insert overwrite $tbl partition (day) " + + s"select id as a, 
str_to_map(concat('t1:','a','&t2:','b'),'&',':'), " + + s"struct('1', null) as c, '2024-01-08' as day from range(10)" + (tbl, sql1, sql2) } } - } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala index 09fa3ff109f2..1b3df81667a0 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/metrics/GlutenClickHouseTPCHMetricsSuite.scala @@ -46,7 +46,7 @@ class GlutenClickHouseTPCHMetricsSuite extends GlutenClickHouseTPCHAbstractSuite .set("spark.io.compression.codec", "LZ4") .set("spark.sql.shuffle.partitions", "1") .set("spark.sql.autoBroadcastJoinThreshold", "10MB") - .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG") + // .set("spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level", "DEBUG") .set( "spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size", s"$parquetMaxBlockSize") diff --git a/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala new file mode 100644 index 000000000000..79616d52d0bc --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/spark/gluten/NativeWriteChecker.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.gluten + +import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.FakeRowAdaptor +import org.apache.spark.sql.util.QueryExecutionListener + +trait NativeWriteChecker extends GlutenClickHouseWholeStageTransformerSuite { + + def checkNativeWrite(sqlStr: String, checkNative: Boolean): Unit = { + var nativeUsed = false + + val queryListener = new QueryExecutionListener { + override def onFailure(f: String, qe: QueryExecution, e: Exception): Unit = {} + override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = { + if (!nativeUsed) { + nativeUsed = if (isSparkVersionGE("3.4")) { + false + } else { + qe.executedPlan.find(_.isInstanceOf[FakeRowAdaptor]).isDefined + } + } + } + } + + try { + spark.listenerManager.register(queryListener) + spark.sql(sqlStr) + spark.sparkContext.listenerBus.waitUntilEmpty() + assertResult(checkNative)(nativeUsed) + } finally { + spark.listenerManager.unregister(queryListener) + } + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala index 1f868c4c2044..7b8d523a6d27 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -827,15 +827,6 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi { buf.result } - /** - * Generate extended columnar post-rules. - * - * @return - */ - override def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { - SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() - } - override def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] = { List(ArrowConvertorRule) } diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 94cd38003bad..ae3f6dbd5208 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -77,6 +77,7 @@ namespace ErrorCodes { extern const int BAD_ARGUMENTS; extern const int UNKNOWN_TYPE; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; } } @@ -466,17 +467,17 @@ String QueryPipelineUtil::explainPipeline(DB::QueryPipeline & pipeline) using namespace DB; -std::map BackendInitializerUtil::getBackendConfMap(std::string * plan) +std::map BackendInitializerUtil::getBackendConfMap(const std::string & plan) { std::map ch_backend_conf; - if (plan == nullptr) + if (plan.empty()) return ch_backend_conf; /// Parse backend configs from plan extensions do { auto plan_ptr = std::make_unique(); - auto success = plan_ptr->ParseFromString(*plan); + auto success = plan_ptr->ParseFromString(plan); if (!success) break; @@ -841,14 +842,8 @@ void BackendInitializerUtil::initCompiledExpressionCache(DB::Context::Configurat #endif } -void BackendInitializerUtil::init_json(std::string * plan_json) -{ - auto plan_ptr = std::make_unique(); - google::protobuf::util::JsonStringToMessage(plan_json->c_str(), plan_ptr.get()); - return init(new String(plan_ptr->SerializeAsString())); -} -void BackendInitializerUtil::init(std::string * plan) +void BackendInitializerUtil::init(const std::string & plan) { std::map backend_conf_map = getBackendConfMap(plan); DB::Context::ConfigurationPtr config = initConfig(backend_conf_map); @@ -906,7 
+901,7 @@ void BackendInitializerUtil::init(std::string * plan) }); } -void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, std::string * plan) +void BackendInitializerUtil::updateConfig(const DB::ContextMutablePtr & context, const std::string & plan) { std::map backend_conf_map = getBackendConfMap(plan); diff --git a/cpp-ch/local-engine/Common/CHUtil.h b/cpp-ch/local-engine/Common/CHUtil.h index 94e0f0168e11..245d7b3d15c4 100644 --- a/cpp-ch/local-engine/Common/CHUtil.h +++ b/cpp-ch/local-engine/Common/CHUtil.h @@ -137,9 +137,8 @@ class BackendInitializerUtil /// Initialize two kinds of resources /// 1. global level resources like global_context/shared_context, notice that they can only be initialized once in process lifetime /// 2. session level resources like settings/configs, they can be initialized multiple times following the lifetime of executor/driver - static void init(std::string * plan); - static void init_json(std::string * plan_json); - static void updateConfig(const DB::ContextMutablePtr &, std::string *); + static void init(const std::string & plan); + static void updateConfig(const DB::ContextMutablePtr &, const std::string &); // use excel text parser @@ -196,7 +195,7 @@ class BackendInitializerUtil static void updateNewSettings(const DB::ContextMutablePtr &, const DB::Settings &); - static std::map getBackendConfMap(std::string * plan); + static std::map getBackendConfMap(const std::string & plan); inline static std::once_flag init_flag; inline static Poco::Logger * logger; @@ -283,10 +282,7 @@ class ConcurrentDeque return deq.empty(); } - std::deque unsafeGet() - { - return deq; - } + std::deque unsafeGet() { return deq; } private: std::deque deq; diff --git a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp index 2b4eb824a5fd..5bb66e4b3f9d 100644 --- a/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp +++ b/cpp-ch/local-engine/Parser/CHColumnToSparkRow.cpp @@ -453,7 +453,7 @@ std::unique_ptr CHColumnToSparkRow::convertCHColumnToSparkRow(cons if (!block.columns()) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A block with empty columns"); std::unique_ptr spark_row_info = std::make_unique(block, masks); - spark_row_info->setBufferAddress(reinterpret_cast(alloc(spark_row_info->getTotalBytes(), 64))); + spark_row_info->setBufferAddress(static_cast(alloc(spark_row_info->getTotalBytes(), 64))); // spark_row_info->setBufferAddress(alignedAlloc(spark_row_info->getTotalBytes(), 64)); memset(spark_row_info->getBufferAddress(), 0, spark_row_info->getTotalBytes()); for (auto col_idx = 0; col_idx < spark_row_info->getNumCols(); col_idx++) diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp index 70db692c8009..3115950cdf09 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.cpp @@ -87,14 +87,14 @@ namespace DB { namespace ErrorCodes { - extern const int LOGICAL_ERROR; - extern const int UNKNOWN_TYPE; - extern const int BAD_ARGUMENTS; - extern const int NO_SUCH_DATA_PART; - extern const int UNKNOWN_FUNCTION; - extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; - extern const int ILLEGAL_TYPE_OF_ARGUMENT; - extern const int INVALID_JOIN_ON_EXPRESSION; +extern const int LOGICAL_ERROR; +extern const int UNKNOWN_TYPE; +extern const int BAD_ARGUMENTS; +extern const int NO_SUCH_DATA_PART; +extern const int UNKNOWN_FUNCTION; +extern const int CANNOT_PARSE_PROTOBUF_SCHEMA; 
+extern const int ILLEGAL_TYPE_OF_ARGUMENT; +extern const int INVALID_JOIN_ON_EXPRESSION; } } @@ -144,16 +144,13 @@ void SerializedPlanParser::parseExtensions( if (extension.has_extension_function()) { function_mapping.emplace( - std::to_string(extension.extension_function().function_anchor()), - extension.extension_function().name()); + std::to_string(extension.extension_function().function_anchor()), extension.extension_function().name()); } } } std::shared_ptr SerializedPlanParser::expressionsToActionsDAG( - const std::vector & expressions, - const Block & header, - const Block & read_schema) + const std::vector & expressions, const Block & header, const Block & read_schema) { auto actions_dag = std::make_shared(blockToNameAndTypeList(header)); NamesWithAliases required_columns; @@ -259,8 +256,8 @@ std::string getDecimalFunction(const substrait::Type_Decimal & decimal, bool nul bool SerializedPlanParser::isReadRelFromJava(const substrait::ReadRel & rel) { - return rel.has_local_files() && rel.local_files().items().size() == 1 && rel.local_files().items().at(0).uri_file().starts_with( - "iterator"); + return rel.has_local_files() && rel.local_files().items().size() == 1 + && rel.local_files().items().at(0).uri_file().starts_with("iterator"); } bool SerializedPlanParser::isReadFromMergeTree(const substrait::ReadRel & rel) @@ -380,13 +377,13 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type) return nested_type; } -QueryPlanPtr SerializedPlanParser::parse(std::unique_ptr plan) +QueryPlanPtr SerializedPlanParser::parse(const substrait::Plan & plan) { - logDebugMessage(*plan, "substrait plan"); - parseExtensions(plan->extensions()); - if (plan->relations_size() == 1) + logDebugMessage(plan, "substrait plan"); + parseExtensions(plan.extensions()); + if (plan.relations_size() == 1) { - auto root_rel = plan->relations().at(0); + auto root_rel = plan.relations().at(0); if (!root_rel.has_root()) { throw Exception(ErrorCodes::BAD_ARGUMENTS, "must have root rel!"); @@ -587,9 +584,7 @@ SerializedPlanParser::getFunctionName(const std::string & function_signature, co { if (args.size() != 2) throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Spark function extract requires two args, function:{}", - function.ShortDebugString()); + ErrorCodes::BAD_ARGUMENTS, "Spark function extract requires two args, function:{}", function.ShortDebugString()); // Get the first arg: field const auto & extract_field = args.at(0); @@ -705,9 +700,7 @@ void SerializedPlanParser::parseArrayJoinArguments( /// The argument number of arrayJoin(converted from Spark explode/posexplode) should be 1 if (scalar_function.arguments_size() != 1) throw Exception( - ErrorCodes::BAD_ARGUMENTS, - "Argument number of arrayJoin should be 1 instead of {}", - scalar_function.arguments_size()); + ErrorCodes::BAD_ARGUMENTS, "Argument number of arrayJoin should be 1 instead of {}", scalar_function.arguments_size()); auto function_name_copy = function_name; parseFunctionArguments(actions_dag, parsed_args, function_name_copy, scalar_function); @@ -746,11 +739,7 @@ void SerializedPlanParser::parseArrayJoinArguments( } ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( - const substrait::Expression & rel, - std::vector & result_names, - ActionsDAGPtr actions_dag, - bool keep_result, - bool position) + const substrait::Expression & rel, std::vector & result_names, ActionsDAGPtr actions_dag, bool keep_result, bool position) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "The root of 
expression should be a scalar function:\n {}", rel.DebugString()); @@ -774,7 +763,8 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -866,10 +856,7 @@ ActionsDAG::NodeRawConstPtrs SerializedPlanParser::parseArrayJoinWithDAG( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!rel.has_scalar_function()) throw Exception(ErrorCodes::BAD_ARGUMENTS, "the root of expression should be a scalar function:\n {}", rel.DebugString()); @@ -884,10 +871,7 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( if (auto func_parser = FunctionParserFactory::instance().tryGet(func_name, this)) { LOG_DEBUG( - &Poco::Logger::get("SerializedPlanParser"), - "parse function {} by function parser: {}", - func_name, - func_parser->getName()); + &Poco::Logger::get("SerializedPlanParser"), "parse function {} by function parser: {}", func_name, func_parser->getName()); const auto * result_node = func_parser->parse(scalar_function, actions_dag); if (keep_result) actions_dag->addOrReplaceInOutputs(*result_node); @@ -956,12 +940,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared(); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } else if (startsWith(function_signature, "make_decimal:")) @@ -976,12 +958,10 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( UInt32 precision = rel.scalar_function().output_type().decimal().precision(); UInt32 scale = rel.scalar_function().output_type().decimal().scale(); auto uint32_type = std::make_shared(); - new_args.emplace_back( - &actions_dag->addColumn( - ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); - new_args.emplace_back( - &actions_dag->addColumn( - 
ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, precision), uint32_type, getUniqueName(toString(precision))))); + new_args.emplace_back(&actions_dag->addColumn( + ColumnWithTypeAndName(uint32_type->createColumnConst(1, scale), uint32_type, getUniqueName(toString(scale))))); args = std::move(new_args); } @@ -999,9 +979,8 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name, CastType::accurateOrNull); } @@ -1011,9 +990,8 @@ const ActionsDAG::Node * SerializedPlanParser::parseFunctionWithDAG( actions_dag, function_node, // as stated in isTypeMatched, currently we don't change nullability of the result type - function_node->result_type->isNullable() - ? local_engine::wrapNullableType(true, result_type)->getName() - : local_engine::removeNullable(result_type)->getName(), + function_node->result_type->isNullable() ? local_engine::wrapNullableType(true, result_type)->getName() + : local_engine::removeNullable(result_type)->getName(), function_node->result_name); } } @@ -1159,9 +1137,7 @@ void SerializedPlanParser::parseFunctionArgument( } const ActionsDAG::Node * SerializedPlanParser::parseFunctionArgument( - ActionsDAGPtr & actions_dag, - const std::string & function_name, - const substrait::FunctionArgument & arg) + ActionsDAGPtr & actions_dag, const std::string & function_name, const substrait::FunctionArgument & arg) { const ActionsDAG::Node * res; if (arg.value().has_scalar_function()) @@ -1189,11 +1165,8 @@ std::pair SerializedPlanParser::convertStructFieldType(const } auto type_id = type->getTypeId(); - if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id == TypeIndex::UInt32 - || type_id == TypeIndex::UInt64) - { + if (type_id == TypeIndex::UInt8 || type_id == TypeIndex::UInt16 || type_id == TypeIndex::UInt32 || type_id == TypeIndex::UInt64) return {type, field}; - } UINT_CONVERT(type, field, Int8) UINT_CONVERT(type, field, Int16) UINT_CONVERT(type, field, Int32) @@ -1203,11 +1176,7 @@ std::pair SerializedPlanParser::convertStructFieldType(const } ActionsDAGPtr SerializedPlanParser::parseFunction( - const Block & header, - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = std::make_shared(blockToNameAndTypeList(header)); @@ -1217,11 +1186,7 @@ ActionsDAGPtr SerializedPlanParser::parseFunction( } ActionsDAGPtr SerializedPlanParser::parseFunctionOrExpression( - const Block & header, - const substrait::Expression & rel, - std::string & result_name, - ActionsDAGPtr actions_dag, - bool keep_result) + const Block & header, const substrait::Expression & rel, std::string & result_name, ActionsDAGPtr actions_dag, bool keep_result) { if (!actions_dag) actions_dag = 
std::make_shared(blockToNameAndTypeList(header)); @@ -1303,7 +1268,8 @@ ActionsDAGPtr SerializedPlanParser::parseJsonTuple( = &actions_dag->addFunction(json_extract_builder, {json_expr_node, extract_expr_node}, json_extract_result_name); auto tuple_element_builder = FunctionFactory::instance().get("sparkTupleElement", context); auto tuple_index_type = std::make_shared(); - auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * { + auto add_tuple_element = [&](const ActionsDAG::Node * tuple_node, size_t i) -> const ActionsDAG::Node * + { ColumnWithTypeAndName index_col(tuple_index_type->createColumnConst(1, i), tuple_index_type, getUniqueName(std::to_string(i))); const auto * index_node = &actions_dag->addColumn(std::move(index_col)); auto result_name = "sparkTupleElement(" + tuple_node->result_name + ", " + index_node->result_name + ")"; @@ -1528,9 +1494,7 @@ std::pair SerializedPlanParser::parseLiteral(const substrait } default: { throw Exception( - ErrorCodes::UNKNOWN_TYPE, - "Unsupported spark literal type {}", - magic_enum::enum_name(literal.literal_type_case())); + ErrorCodes::UNKNOWN_TYPE, "Unsupported spark literal type {}", magic_enum::enum_name(literal.literal_type_case())); } } return std::make_pair(std::move(type), std::move(field)); @@ -1732,8 +1696,7 @@ substrait::ReadRel::ExtensionTable SerializedPlanParser::parseExtensionTable(con { substrait::ReadRel::ExtensionTable extension_table; google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(split_info.data()), - static_cast(split_info.size())); + reinterpret_cast(split_info.data()), static_cast(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = extension_table.ParseFromCodedStream(&coded_in); @@ -1747,8 +1710,7 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std:: { substrait::ReadRel::LocalFiles local_files; google::protobuf::io::CodedInputStream coded_in( - reinterpret_cast(split_info.data()), - static_cast(split_info.size())); + reinterpret_cast(split_info.data()), static_cast(split_info.size())); coded_in.SetRecursionLimit(100000); auto ok = local_files.ParseFromCodedStream(&coded_in); @@ -1758,10 +1720,44 @@ substrait::ReadRel::LocalFiles SerializedPlanParser::parseLocalFiles(const std:: return local_files; } +std::unique_ptr SerializedPlanParser::createExecutor(DB::QueryPlanPtr query_plan) +{ + Stopwatch stopwatch; + auto * logger = &Poco::Logger::get("SerializedPlanParser"); + const Settings & settings = context->getSettingsRef(); + + QueryPriorities priorities; + auto query_status = std::make_shared( + context, + "", + context->getClientInfo(), + priorities.insert(static_cast(settings.priority)), + CurrentThread::getGroup(), + IAST::QueryKind::Select, + settings, + 0); + + QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; + auto pipeline_builder = query_plan->buildQueryPipeline( + optimization_settings, + BuildQueryPipelineSettings{ + .actions_settings + = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, .compile_expressions = CompileExpressions::yes}, + .process_list_element = query_status}); + QueryPipeline pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); + LOG_INFO(logger, "build pipeline {} ms", stopwatch.elapsedMicroseconds() / 1000.0); + + LOG_DEBUG( + logger, "clickhouse plan [optimization={}]:\n{}", settings.query_plan_enable_optimizations, 
PlanUtil::explainPlan(*query_plan)); + LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(pipeline)); + + return std::make_unique( + context, std::move(query_plan), std::move(pipeline), query_plan->getCurrentDataStream().header.cloneEmpty()); +} -QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) +QueryPlanPtr SerializedPlanParser::parse(const std::string_view & plan) { - auto plan_ptr = std::make_unique(); + substrait::Plan s_plan; /// https://stackoverflow.com/questions/52028583/getting-error-parsing-protobuf-data /// Parsing may fail when the number of recursive layers is large. /// Here, set a limit large enough to avoid this problem. @@ -1769,11 +1765,10 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) google::protobuf::io::CodedInputStream coded_in(reinterpret_cast(plan.data()), static_cast(plan.size())); coded_in.SetRecursionLimit(100000); - auto ok = plan_ptr->ParseFromCodedStream(&coded_in); - if (!ok) + if (!s_plan.ParseFromCodedStream(&coded_in)) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from string failed"); - auto res = parse(std::move(plan_ptr)); + auto res = parse(s_plan); #ifndef NDEBUG PlanUtil::checkOuputType(*res); @@ -1788,17 +1783,16 @@ QueryPlanPtr SerializedPlanParser::parse(const std::string & plan) return res; } -QueryPlanPtr SerializedPlanParser::parseJson(const std::string & json_plan) +QueryPlanPtr SerializedPlanParser::parseJson(const std::string_view & json_plan) { - auto plan_ptr = std::make_unique(); - auto s = google::protobuf::util::JsonStringToMessage(absl::string_view(json_plan), plan_ptr.get()); + substrait::Plan plan; + auto s = google::protobuf::util::JsonStringToMessage(json_plan, &plan); if (!s.ok()) throw Exception(ErrorCodes::CANNOT_PARSE_PROTOBUF_SCHEMA, "Parse substrait::Plan from json string failed: {}", s.ToString()); - return parse(std::move(plan_ptr)); + return parse(plan); } -SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) - : context(context_) +SerializedPlanParser::SerializedPlanParser(const ContextPtr & context_) : context(context_) { } @@ -1807,13 +1801,10 @@ ContextMutablePtr SerializedPlanParser::global_context = nullptr; Context::ConfigurationPtr SerializedPlanParser::config = nullptr; void SerializedPlanParser::collectJoinKeys( - const substrait::Expression & condition, - std::vector> & join_keys, - int32_t right_key_start) + const substrait::Expression & condition, std::vector> & join_keys, int32_t right_key_start) { auto condition_name = getFunctionName( - function_mapping.at(std::to_string(condition.scalar_function().function_reference())), - condition.scalar_function()); + function_mapping.at(std::to_string(condition.scalar_function().function_reference())), condition.scalar_function()); if (condition_name == "and") { collectJoinKeys(condition.scalar_function().arguments(0).value(), join_keys, right_key_start); @@ -1863,8 +1854,8 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & auto substrait_name = function_signature.substr(0, function_signature.find(':')); auto func_parser = FunctionParserFactory::instance().tryGet(substrait_name, plan_parser); - String function_name = func_parser ? func_parser->getName() - : SerializedPlanParser::getFunctionName(function_signature, scalar_function); + String function_name + = func_parser ? 
func_parser->getName() : SerializedPlanParser::getFunctionName(function_signature, scalar_function); ASTs ast_args; parseFunctionArgumentsToAST(names, scalar_function, ast_args); @@ -1876,9 +1867,7 @@ ASTPtr ASTParser::parseToAST(const Names & names, const substrait::Expression & } void ASTParser::parseFunctionArgumentsToAST( - const Names & names, - const substrait::Expression_ScalarFunction & scalar_function, - ASTs & ast_args) + const Names & names, const substrait::Expression_ScalarFunction & scalar_function, ASTs & ast_args) { const auto & args = scalar_function.arguments(); @@ -2021,12 +2010,12 @@ ASTPtr ASTParser::parseArgumentToAST(const Names & names, const substrait::Expre } } -void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAGPtr actions_dag) +void SerializedPlanParser::removeNullableForRequiredColumns( + const std::set & require_columns, const ActionsDAGPtr & actions_dag) const { for (const auto & item : require_columns) { - const auto * require_node = actions_dag->tryFindInOutputs(item); - if (require_node) + if (const auto * require_node = actions_dag->tryFindInOutputs(item)) { auto function_builder = FunctionFactory::instance().get("assumeNotNull", context); ActionsDAG::NodeRawConstPtrs args = {require_node}; @@ -2037,9 +2026,7 @@ void SerializedPlanParser::removeNullableForRequiredColumns(const std::set & columns, - ActionsDAGPtr actions_dag, - std::map & nullable_measure_names) + const std::vector & columns, ActionsDAGPtr actions_dag, std::map & nullable_measure_names) { for (const auto & item : columns) { @@ -2092,86 +2079,23 @@ LocalExecutor::~LocalExecutor() } } - -void LocalExecutor::execute(QueryPlanPtr query_plan) -{ - Stopwatch stopwatch; - - const Settings & settings = context->getSettingsRef(); - current_query_plan = std::move(query_plan); - auto * logger = &Poco::Logger::get("LocalExecutor"); - - QueryPriorities priorities; - auto query_status = std::make_shared( - context, - "", - context->getClientInfo(), - priorities.insert(static_cast(settings.priority)), - CurrentThread::getGroup(), - IAST::QueryKind::Select, - settings, - 0); - - QueryPlanOptimizationSettings optimization_settings{.optimize_plan = settings.query_plan_enable_optimizations}; - auto pipeline_builder = current_query_plan->buildQueryPipeline( - optimization_settings, - BuildQueryPipelineSettings{ - .actions_settings - = ExpressionActionsSettings{.can_compile_expressions = true, .min_count_to_compile_expression = 3, - .compile_expressions = CompileExpressions::yes}, - .process_list_element = query_status}); - - LOG_DEBUG(logger, "clickhouse plan after optimization:\n{}", PlanUtil::explainPlan(*current_query_plan)); - query_pipeline = QueryPipelineBuilder::getPipeline(std::move(*pipeline_builder)); - LOG_DEBUG(logger, "clickhouse pipeline:\n{}", QueryPipelineUtil::explainPipeline(query_pipeline)); - auto t_pipeline = stopwatch.elapsedMicroseconds(); - - executor = std::make_unique(query_pipeline); - auto t_executor = stopwatch.elapsedMicroseconds() - t_pipeline; - stopwatch.stop(); - LOG_INFO( - logger, - "build pipeline {} ms; create executor {} ms;", - t_pipeline / 1000.0, - t_executor / 1000.0); - - header = current_query_plan->getCurrentDataStream().header.cloneEmpty(); - ch_column_to_spark_row = std::make_unique(); -} - -std::unique_ptr LocalExecutor::writeBlockToSparkRow(Block & block) +std::unique_ptr LocalExecutor::writeBlockToSparkRow(const Block & block) const { return ch_column_to_spark_row->convertCHColumnToSparkRow(block); } bool 
LocalExecutor::hasNext() { - bool has_next; - try + size_t columns = currentBlock().columns(); + if (columns == 0 || isConsumed()) { - size_t columns = currentBlock().columns(); - if (columns == 0 || isConsumed()) - { - auto empty_block = header.cloneEmpty(); - setCurrentBlock(empty_block); - has_next = executor->pull(currentBlock()); - produce(); - } - else - { - has_next = true; - } - } - catch (Exception & e) - { - LOG_ERROR( - &Poco::Logger::get("LocalExecutor"), - "LocalExecutor run query plan failed with message: {}. Plan Explained: \n{}", - e.message(), - PlanUtil::explainPlan(*current_query_plan)); - throw; + auto empty_block = header.cloneEmpty(); + setCurrentBlock(empty_block); + bool has_next = executor->pull(currentBlock()); + produce(); + return has_next; } - return has_next; + return true; } SparkRowInfoPtr LocalExecutor::next() @@ -2246,12 +2170,17 @@ Block & LocalExecutor::getHeader() return header; } -LocalExecutor::LocalExecutor(ContextPtr context_) - : context(context_) +LocalExecutor::LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan, QueryPipeline && pipeline, const Block & header_) + : query_pipeline(std::move(pipeline)) + , executor(std::make_unique(query_pipeline)) + , header(header_) + , context(context_) + , ch_column_to_spark_row(std::make_unique()) + , current_query_plan(std::move(query_plan)) { } -std::string LocalExecutor::dumpPipeline() +std::string LocalExecutor::dumpPipeline() const { const auto & processors = query_pipeline.getProcessors(); for (auto & processor : processors) @@ -2275,12 +2204,8 @@ std::string LocalExecutor::dumpPipeline() } NonNullableColumnsResolver::NonNullableColumnsResolver( - const Block & header_, - SerializedPlanParser & parser_, - const substrait::Expression & cond_rel_) - : header(header_) - , parser(parser_) - , cond_rel(cond_rel_) + const Block & header_, SerializedPlanParser & parser_, const substrait::Expression & cond_rel_) + : header(header_), parser(parser_), cond_rel(cond_rel_) { } @@ -2352,8 +2277,7 @@ void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression & } std::string NonNullableColumnsResolver::safeGetFunctionName( - const std::string & function_signature, - const substrait::Expression_ScalarFunction & function) + const std::string & function_signature, const substrait::Expression_ScalarFunction & function) const { try { diff --git a/cpp-ch/local-engine/Parser/SerializedPlanParser.h b/cpp-ch/local-engine/Parser/SerializedPlanParser.h index 71cdca58a6ce..82e8c4077841 100644 --- a/cpp-ch/local-engine/Parser/SerializedPlanParser.h +++ b/cpp-ch/local-engine/Parser/SerializedPlanParser.h @@ -218,6 +218,7 @@ DataTypePtr wrapNullableType(bool nullable, DataTypePtr nested_type); std::string join(const ActionsDAG::NodeRawConstPtrs & v, char c); class SerializedPlanParser; +class LocalExecutor; // Give a condition expression `cond_rel_`, found all columns with nullability that must not containt // null after this filter. 
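The hunks above replace LocalExecutor's two-phase initialization (default-construct, then execute(query_plan) builds the pipeline and header) with a constructor that receives the query plan, an already-built QueryPipeline and the output header, while pipeline construction itself moves into SerializedPlanParser::createExecutor. A minimal sketch of that construction pattern, using placeholder stub types rather than the real ClickHouse/Gluten classes:

    #include <memory>
    #include <string>
    #include <utility>

    // Stand-ins for QueryPlan, QueryPipeline and the output header (illustrative only).
    struct QueryPlanStub { std::string description; };
    struct PipelineStub { std::string description; };
    struct HeaderStub { int columns = 0; };

    // After the refactor an executor is fully usable the moment it exists;
    // there is no separate execute() step that could be forgotten or reordered.
    class ExecutorStub
    {
    public:
        ExecutorStub(std::unique_ptr<QueryPlanStub> plan, PipelineStub && pipeline, HeaderStub header)
            : plan_(std::move(plan)), pipeline_(std::move(pipeline)), header_(std::move(header))
        {
        }

        const HeaderStub & header() const { return header_; }

    private:
        std::unique_ptr<QueryPlanStub> plan_;
        PipelineStub pipeline_;
        HeaderStub header_;
    };

    // Factory in the spirit of SerializedPlanParser::createExecutor:
    // the pipeline is built here, before the executor is constructed.
    std::unique_ptr<ExecutorStub> createExecutorStub(std::unique_ptr<QueryPlanStub> plan)
    {
        PipelineStub pipeline{"pipeline built from " + plan->description};
        HeaderStub header{2};
        return std::make_unique<ExecutorStub>(std::move(plan), std::move(pipeline), std::move(header));
    }

    int main()
    {
        auto executor = createExecutorStub(std::make_unique<QueryPlanStub>(QueryPlanStub{"substrait plan"}));
        return executor->header().columns == 2 ? 0 : 1;
    }

The practical effect is that callers can no longer observe a half-initialized executor, and the pipeline-build timing and logging now sit next to the parser that owns the plan.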
@@ -241,7 +242,7 @@ class NonNullableColumnsResolver void visit(const substrait::Expression & expr); void visitNonNullable(const substrait::Expression & expr); - String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function); + String safeGetFunctionName(const String & function_signature, const substrait::Expression_ScalarFunction & function) const; }; class SerializedPlanParser @@ -257,11 +258,21 @@ class SerializedPlanParser friend class JoinRelParser; friend class MergeTreeRelParser; + std::unique_ptr createExecutor(DB::QueryPlanPtr query_plan); + + DB::QueryPlanPtr parse(const std::string_view & plan); + DB::QueryPlanPtr parse(const substrait::Plan & plan); + public: explicit SerializedPlanParser(const ContextPtr & context); - DB::QueryPlanPtr parse(const std::string & plan); - DB::QueryPlanPtr parseJson(const std::string & json_plan); - DB::QueryPlanPtr parse(std::unique_ptr plan); + + /// UT only + DB::QueryPlanPtr parseJson(const std::string_view & json_plan); + std::unique_ptr createExecutor(const substrait::Plan & plan) { return createExecutor(parse((plan))); } + /// + + template + std::unique_ptr createExecutor(const std::string_view & plan); DB::QueryPlanStepPtr parseReadRealWithLocalFile(const substrait::ReadRel & rel); DB::QueryPlanStepPtr parseReadRealWithJavaIter(const substrait::ReadRel & rel); @@ -372,7 +383,7 @@ class SerializedPlanParser const ActionsDAG::Node * toFunctionNode(ActionsDAGPtr actions_dag, const String & function, const DB::ActionsDAG::NodeRawConstPtrs & args); // remove nullable after isNotNull - void removeNullableForRequiredColumns(const std::set & require_columns, ActionsDAGPtr actions_dag); + void removeNullableForRequiredColumns(const std::set & require_columns, const ActionsDAGPtr & actions_dag) const; std::string getUniqueName(const std::string & name) { return name + "_" + std::to_string(name_no++); } static std::pair parseLiteral(const substrait::Expression_Literal & literal); void wrapNullable( @@ -394,6 +405,12 @@ class SerializedPlanParser const ActionsDAG::Node * addColumn(DB::ActionsDAGPtr actions_dag, const DataTypePtr & type, const Field & field); }; +template +std::unique_ptr SerializedPlanParser::createExecutor(const std::string_view & plan) +{ + return createExecutor(JsonPlan ? 
parseJson(plan) : parse(plan)); +} + struct SparkBuffer { char * address; @@ -403,16 +420,14 @@ struct SparkBuffer class LocalExecutor : public BlockIterator { public: - LocalExecutor() = default; - explicit LocalExecutor(ContextPtr context); + LocalExecutor(const ContextPtr & context_, QueryPlanPtr query_plan, QueryPipeline && pipeline, const Block & header_); ~LocalExecutor(); - void execute(QueryPlanPtr query_plan); SparkRowInfoPtr next(); Block * nextColumnar(); bool hasNext(); - /// Stop execution and wait for pipeline exit, used when task receives shutdown command or executor receives SIGTERM signal + /// Stop execution, used when task receives shutdown command or executor receives SIGTERM signal void cancel(); Block & getHeader(); @@ -425,13 +440,13 @@ class LocalExecutor : public BlockIterator static void removeExecutor(Int64 handle); private: - std::unique_ptr writeBlockToSparkRow(DB::Block & block); + std::unique_ptr writeBlockToSparkRow(const DB::Block & block) const; void asyncCancel(); void waitCancelFinished(); /// Dump processor runtime information to log - std::string dumpPipeline(); + std::string dumpPipeline() const; QueryPipeline query_pipeline; std::unique_ptr executor; @@ -439,7 +454,7 @@ class LocalExecutor : public BlockIterator ContextPtr context; std::unique_ptr ch_column_to_spark_row; std::unique_ptr spark_buffer; - DB::QueryPlanPtr current_query_plan; + QueryPlanPtr current_query_plan; RelMetricPtr metric; std::vector extra_plan_holder; std::atomic is_cancelled{false}; diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index bbc467879182..9c642d70ec27 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -224,11 +224,9 @@ JNIEXPORT void JNI_OnUnload(JavaVM * vm, void * /*reserved*/) JNIEXPORT void Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_nativeInitNative(JNIEnv * env, jobject, jbyteArray conf_plan) { LOCAL_ENGINE_JNI_METHOD_START - jsize plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan); jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string plan_str; - plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); - local_engine::BackendInitializerUtil::init(&plan_str); + local_engine::BackendInitializerUtil::init({reinterpret_cast(plan_buf_addr), plan_buf_size}); env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT); LOCAL_ENGINE_JNI_METHOD_END(env, ) } @@ -254,11 +252,9 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ auto query_context = local_engine::getAllocator(allocator_id)->query_context; // by task update new configs ( in case of dynamic config update ) - jsize plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type plan_buf_size = env->GetArrayLength(conf_plan); jbyte * plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string plan_str; - plan_str.assign(reinterpret_cast(plan_buf_addr), plan_buf_size); - local_engine::BackendInitializerUtil::updateConfig(query_context, &plan_str); + local_engine::BackendInitializerUtil::updateConfig(query_context, {reinterpret_cast(plan_buf_addr), plan_buf_size}); local_engine::SerializedPlanParser parser(query_context); jsize iter_num = env->GetArrayLength(iter_arr); @@ -277,17 +273,14 @@ JNIEXPORT jlong Java_org_apache_gluten_vectorized_ExpressionEvaluatorJniWrapper_ 
parser.addSplitInfo(std::string{reinterpret_cast(split_info_addr), split_info_size}); } - jsize plan_size = env->GetArrayLength(plan); + std::string::size_type plan_size = env->GetArrayLength(plan); jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); - std::string plan_string; - plan_string.assign(reinterpret_cast(plan_address), plan_size); - auto query_plan = parser.parse(plan_string); - local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(query_context); + local_engine::LocalExecutor * executor + = parser.createExecutor({reinterpret_cast(plan_address), plan_size}).release(); local_engine::LocalExecutor::addExecutor(executor); - LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); + LOG_INFO(&Poco::Logger::get("jni"), "Construct LocalExecutor {}", reinterpret_cast(executor)); executor->setMetric(parser.getMetric()); executor->setExtraPlanHolder(parser.extra_plan_holder); - executor->execute(std::move(query_plan)); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); env->ReleaseByteArrayElements(conf_plan, plan_buf_addr, JNI_ABORT); return reinterpret_cast(executor); @@ -932,11 +925,10 @@ JNIEXPORT jlong Java_org_apache_spark_sql_execution_datasources_CHDatasourceJniW LOCAL_ENGINE_JNI_METHOD_START auto query_context = local_engine::getAllocator(allocator_id)->query_context; // by task update new configs ( in case of dynamic config update ) - jsize conf_plan_buf_size = env->GetArrayLength(conf_plan); + std::string::size_type conf_plan_buf_size = env->GetArrayLength(conf_plan); jbyte * conf_plan_buf_addr = env->GetByteArrayElements(conf_plan, nullptr); - std::string conf_plan_str; - conf_plan_str.assign(reinterpret_cast(conf_plan_buf_addr), conf_plan_buf_size); - local_engine::BackendInitializerUtil::updateConfig(query_context, &conf_plan_str); + local_engine::BackendInitializerUtil::updateConfig( + query_context, {reinterpret_cast(conf_plan_buf_addr), conf_plan_buf_size}); const auto uuid_str = jstring2string(env, uuid_); const auto task_id = jstring2string(env, task_id_); @@ -1329,14 +1321,11 @@ Java_org_apache_gluten_vectorized_SimpleExpressionEval_createNativeInstance(JNIE local_engine::SerializedPlanParser parser(context); jobject iter = env->NewGlobalRef(input); parser.addInputIter(iter, false); - jsize plan_size = env->GetArrayLength(plan); + std::string::size_type plan_size = env->GetArrayLength(plan); jbyte * plan_address = env->GetByteArrayElements(plan, nullptr); - std::string plan_string; - plan_string.assign(reinterpret_cast(plan_address), plan_size); - auto query_plan = parser.parse(plan_string); - local_engine::LocalExecutor * executor = new local_engine::LocalExecutor(context); + local_engine::LocalExecutor * executor + = parser.createExecutor({reinterpret_cast(plan_address), plan_size}).release(); local_engine::LocalExecutor::addExecutor(executor); - executor->execute(std::move(query_plan)); env->ReleaseByteArrayElements(plan, plan_address, JNI_ABORT); return reinterpret_cast(executor); LOCAL_ENGINE_JNI_METHOD_END(env, -1) diff --git a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp index 89fa4fa961ea..208a3b518d45 100644 --- a/cpp-ch/local-engine/tests/benchmark_local_engine.cpp +++ b/cpp-ch/local-engine/tests/benchmark_local_engine.cpp @@ -154,14 +154,11 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(global_context); - auto query_plan = parser.parse(std::move(plan)); 
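At the call sites touched here (the JNI wrappers and the benchmark hunks), the old sequence "parse the plan, default-construct a LocalExecutor, call execute()" collapses into a single parser.createExecutor(...) call followed by the usual hasNext()/next() loop. The header change also adds a templated createExecutor whose bool parameter selects between the JSON plan path (parseJson) and the binary Substrait path (parse). A small sketch of that compile-time dispatch, using toy stand-in functions rather than the real parser methods:

    #include <iostream>
    #include <string_view>

    // Stand-ins for the two parse paths; the real methods return a query plan/executor.
    int parseBinaryPlan(std::string_view plan) { return static_cast<int>(plan.size()); }
    int parseJsonPlan(std::string_view plan) { return static_cast<int>(plan.size()); }

    // In the spirit of createExecutor<JsonPlan>: the branch is fixed per instantiation,
    // so production call sites and JSON-based tests can share one entry point.
    template <bool JsonPlan>
    int createExecutorStub(std::string_view plan)
    {
        return JsonPlan ? parseJsonPlan(plan) : parseBinaryPlan(plan);
    }

    int main()
    {
        std::cout << createExecutorStub<false>("raw substrait bytes") << "\n";   // binary plan
        std::cout << createExecutorStub<true>(R"({"relations":[]})") << "\n";    // JSON plan
    }

The patch keeps a plain conditional on the template parameter; an if constexpr branch would behave the same here since the condition is a compile-time constant.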
- local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); - } + + while (local_executor->hasNext()) + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); } } @@ -212,13 +209,12 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) + + while (local_executor->hasNext()) { - Block * block = local_executor.nextColumnar(); + Block * block = local_executor->nextColumnar(); delete block; } } @@ -238,15 +234,10 @@ DB::ContextMutablePtr global_context; std::ifstream t(path); std::string str((std::istreambuf_iterator(t)), std::istreambuf_iterator()); std::cout << "the plan from: " << path << std::endl; - - auto query_plan = parser.parse(str); - local_engine::LocalExecutor local_executor; + auto local_executor = parser.createExecutor(str); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - [[maybe_unused]] auto * x = local_executor.nextColumnar(); - } + while (local_executor->hasNext()) [[maybe_unused]] + auto * x = local_executor->nextColumnar(); } } @@ -282,14 +273,12 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; + + auto local_executor = parser.createExecutor(*plan); state.ResumeTiming(); - local_executor.execute(std::move(query_plan)); - while (local_executor.hasNext()) - { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); - } + + while (local_executor->hasNext()) + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); } } @@ -320,16 +309,13 @@ DB::ContextMutablePtr global_context; .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; - - local_executor.execute(std::move(query_plan)); + auto local_executor = parser.createExecutor(*plan); local_engine::SparkRowToCHColumn converter; - while (local_executor.hasNext()) + while (local_executor->hasNext()) { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); state.ResumeTiming(); - auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor->getHeader()); state.PauseTiming(); } state.ResumeTiming(); @@ -368,16 +354,13 @@ DB::ContextMutablePtr global_context; std::move(schema)) .build(); local_engine::SerializedPlanParser parser(SerializedPlanParser::global_context); - auto query_plan = parser.parse(std::move(plan)); - local_engine::LocalExecutor local_executor; - - local_executor.execute(std::move(query_plan)); + auto local_executor = parser.createExecutor(*plan); local_engine::SparkRowToCHColumn converter; - while 
(local_executor.hasNext()) + while (local_executor->hasNext()) { - local_engine::SparkRowInfoPtr spark_row_info = local_executor.next(); + local_engine::SparkRowInfoPtr spark_row_info = local_executor->next(); state.ResumeTiming(); - auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor.getHeader()); + auto block = converter.convertSparkRowInfoToCHColumn(*spark_row_info, local_executor->getHeader()); state.PauseTiming(); } state.ResumeTiming(); @@ -485,12 +468,8 @@ DB::ContextMutablePtr global_context; y.reserve(cnt); for (auto _ : state) - { for (i = 0; i < cnt; i++) - { y[i] = add(x[i], i); - } - } } [[maybe_unused]] static void BM_TestSumInline(benchmark::State & state) @@ -504,12 +483,8 @@ DB::ContextMutablePtr global_context; y.reserve(cnt); for (auto _ : state) - { for (i = 0; i < cnt; i++) - { y[i] = x[i] + i; - } - } } [[maybe_unused]] static void BM_TestPlus(benchmark::State & state) @@ -545,9 +520,7 @@ DB::ContextMutablePtr global_context; block.insert(y); auto executable_function = function->prepare(arguments); for (auto _ : state) - { auto result = executable_function->execute(block.getColumnsWithTypeAndName(), type, rows, false); - } } [[maybe_unused]] static void BM_TestPlusEmbedded(benchmark::State & state) @@ -847,9 +820,7 @@ QueryPlanPtr joinPlan(QueryPlanPtr left, QueryPlanPtr right, String left_key, St ASTPtr rkey = std::make_shared(right_key); join->addOnKeys(lkey, rkey, true); for (const auto & column : join->columnsFromJoinedTable()) - { join->addJoinedColumn(column); - } auto left_keys = left->getCurrentDataStream().header.getNamesAndTypesList(); join->addJoinedColumnsAndCorrectTypes(left_keys, true); @@ -920,7 +891,8 @@ BENCHMARK(BM_ParquetRead)->Unit(benchmark::kMillisecond)->Iterations(10); int main(int argc, char ** argv) { - BackendInitializerUtil::init(nullptr); + std::string empty; + BackendInitializerUtil::init(empty); SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); }); ::benchmark::Initialize(&argc, argv); diff --git a/cpp-ch/local-engine/tests/gluten_test_util.h b/cpp-ch/local-engine/tests/gluten_test_util.h index d4c16e9fbbd8..dba4496d6221 100644 --- a/cpp-ch/local-engine/tests/gluten_test_util.h +++ b/cpp-ch/local-engine/tests/gluten_test_util.h @@ -24,6 +24,7 @@ #include #include #include +#include #include using BlockRowType = DB::ColumnsWithTypeAndName; @@ -60,6 +61,23 @@ AnotherRowType readParquetSchema(const std::string & file); DB::ActionsDAGPtr parseFilter(const std::string & filter, const AnotherRowType & name_and_types); +namespace pb_util +{ +template +std::string JsonStringToBinary(const std::string_view & json) +{ + Message message; + std::string binary; + auto s = google::protobuf::util::JsonStringToMessage(json, &message); + if (!s.ok()) + { + const std::string err_msg{s.message()}; + throw std::runtime_error(err_msg); + } + message.SerializeToString(&binary); + return binary; +} +} } inline DB::DataTypePtr BIGINT() diff --git a/cpp-ch/local-engine/tests/gtest_local_engine.cpp b/cpp-ch/local-engine/tests/gtest_local_engine.cpp index 2d1807841041..962bf9def52e 100644 --- a/cpp-ch/local-engine/tests/gtest_local_engine.cpp +++ b/cpp-ch/local-engine/tests/gtest_local_engine.cpp @@ -16,9 +16,12 @@ */ #include #include +#include +#include + #include -#include #include +#include #include #include #include @@ -28,7 +31,6 @@ #include #include #include -#include #include #include #include @@ -84,13 +86,23 @@ TEST(ReadBufferFromFile, seekBackwards) ASSERT_EQ(x, 8); } +INCBIN(resource_embedded_config_json, 
SOURCE_DIR "/utils/extern-local-engine/tests/json/gtest_local_engine_config.json"); + +namespace DB +{ +void registerOutputFormatParquet(DB::FormatFactory & factory); +} + int main(int argc, char ** argv) { - auto * init = new String("{\"advancedExtensions\":{\"enhancement\":{\"@type\":\"type.googleapis.com/substrait.Expression\",\"literal\":{\"map\":{\"keyValues\":[{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level\"},\"value\":{\"string\":\"trace\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.endpoint\"},\"value\":{\"string\":\"localhost:9000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.velox.IOThreads\"},\"value\":{\"string\":\"0\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations\"},\"value\":{\"string\":\"false\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.worker.id\"},\"value\":{\"string\":\"1\"}},{\"key\":{\"string\":\"spark.memory.offHeap.enabled\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.iam.role.session.name\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.shuffle.codec\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.gluten.memory.offHeap.size.in.bytes\"},\"value\":{\"string\":\"10737418240\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.shuffle.codecBackend\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.sql.orc.compression.codec\"},\"value\":{\"string\":\"snappy\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by\"},\"value\":{\"string\":\"5368709120\"}},{\"key\":{\"string\":\"spark.hadoop.input.write.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.secret.key\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.access.key\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity\"},\"value\":{\"string\":\"INFO\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.path.style.access\"},\"value\":{\"string\":\"true\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.timezone\"},\"value\":{\"string\":\"Asia/Shanghai\"}},{\"key\":{\"string\":\"spark.hadoop.input.read.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.use.instance.credentials\"},\"value\":{\"string\":\"false\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method\"},\"value\":{\"string\":\"snappy\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.iam.role\"},\"value\":{\"string\":\"\"}},{\"key\":{\"string\":\"spark.gluten.memory.task.offHeap.size.in.bytes\"},\"value\":{\"string\":\"10737418240\"}},{\"key\":{\"string\":\"spark.hadoop.input.connect.timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"
string\":\"spark.hadoop.dfs.client.log.severity\"},\"value\":{\"string\":\"INFO\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver\"},\"value\":{\"string\":\"2\"}},{\"key\":{\"string\":\"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout\"},\"value\":{\"string\":\"180000\"}},{\"key\":{\"string\":\"spark.hadoop.fs.s3a.connection.ssl.enabled\"},\"value\":{\"string\":\"false\"}}]}}}}}"); + BackendInitializerUtil::init(test::pb_util::JsonStringToBinary( + {reinterpret_cast(gresource_embedded_config_jsonData), gresource_embedded_config_jsonSize})); + + auto & factory = FormatFactory::instance(); + DB::registerOutputFormatParquet(factory); - BackendInitializerUtil::init_json(std::move(init)); SCOPE_EXIT({ BackendFinalizerUtil::finalizeGlobally(); }); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); -} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/gtest_parser.cpp b/cpp-ch/local-engine/tests/gtest_parser.cpp index cbe41c90c81a..485740191ea3 100644 --- a/cpp-ch/local-engine/tests/gtest_parser.cpp +++ b/cpp-ch/local-engine/tests/gtest_parser.cpp @@ -14,307 +14,140 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include +#include #include -#include #include + using namespace local_engine; using namespace DB; -std::string splitBinaryFromJson(const std::string & json) +// Plan for https://github.com/ClickHouse/ClickHouse/pull/65234 +INCBIN(resource_embedded_pr_65234_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/clickhouse_pr_65234.json"); + +TEST(SerializedPlanParser, PR65234) { - std::string binary; - substrait::ReadRel::LocalFiles local_files; - auto s = google::protobuf::util::JsonStringToMessage(absl::string_view(json), &local_files); - local_files.SerializeToString(&binary); - return binary; + const std::string split + = R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-clickhouse/target/scala-2.12/test-classes/tests-working-home/tpch-data/supplier/part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]})"; + SerializedPlanParser parser(SerializedPlanParser::global_context); + parser.addSplitInfo(test::pb_util::JsonStringToBinary(split)); + auto query_plan + = parser.parseJson({reinterpret_cast(gresource_embedded_pr_65234_jsonData), gresource_embedded_pr_65234_jsonSize}); } -std::string JsonPlanFor65234() +#include +#include +#include +#include +#include + +Chunk testChunk() { - // Plan for https://github.com/ClickHouse/ClickHouse/pull/65234 - return R"( + auto nameCol = STRING()->createColumn(); + nameCol->insert("one"); + nameCol->insert("two"); + nameCol->insert("three"); + + auto valueCol = UINT()->createColumn(); + valueCol->insert(1); + valueCol->insert(2); + valueCol->insert(3); + MutableColumns x; + x.push_back(std::move(nameCol)); + x.push_back(std::move(valueCol)); + return {std::move(x), 3}; +} + +TEST(LocalExecutor, StorageObjectStorageSink) { - "extensions": [{ - "extensionFunction": { - "functionAnchor": 1, - "name": "is_not_null:str" - } - }, { - "extensionFunction": { - "functionAnchor": 2, - "name": "equal:str_str" - } - }, { - "extensionFunction": { - "functionAnchor": 3, - "name": "is_not_null:i64" - } - }, { - "extensionFunction": { - "name": "and:bool_bool" - } - }], - "relations": [{ - "root": { - "input": { - "project": { - "common": { - "emit": { - "outputMapping": [2] - } - }, - "input": { - 
"filter": { - "common": { - "direct": { - } - }, - "input": { - "read": { - "common": { - "direct": { - } - }, - "baseSchema": { - "names": ["r_regionkey", "r_name"], - "struct": { - "types": [{ - "i64": { - "nullability": "NULLABILITY_NULLABLE" - } - }, { - "string": { - "nullability": "NULLABILITY_NULLABLE" - } - }] - }, - "columnTypes": ["NORMAL_COL", "NORMAL_COL"] - }, - "filter": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "functionReference": 1, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 2, - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }, { - "value": { - "literal": { - "string": "EUROPE" - } - } - }] - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 3, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - } - } - } - } - }] - } - } - }] - } - }, - "advancedExtension": { - "optimization": { - "@type": "type.googleapis.com/google.protobuf.StringValue", - "value": "isMergeTree\u003d0\n" - } - } - } - }, - "condition": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "scalarFunction": { - "functionReference": 1, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 2, - "outputType": { - "bool": { - "nullability": "NULLABILITY_NULLABLE" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - "field": 1 - } - } - } - } - }, { - "value": { - "literal": { - "string": "EUROPE" - } - } - }] - } - } - }] - } - } - }, { - "value": { - "scalarFunction": { - "functionReference": 3, - "outputType": { - "bool": { - "nullability": "NULLABILITY_REQUIRED" - } - }, - "arguments": [{ - "value": { - "selection": { - "directReference": { - "structField": { - } - } - } - } - }] - } - } - }] - } - } - } - }, - "expressions": [{ - "selection": { - "directReference": { - "structField": { - } - } - } - }] - } - }, - "names": ["r_regionkey#72"], - "outputSchema": { - "types": [{ - "i64": { - "nullability": "NULLABILITY_NULLABLE" - } - }], - "nullability": "NULLABILITY_REQUIRED" - } - } - }] + /// 0. 
Create ObjectStorage for HDFS + auto settings = SerializedPlanParser::global_context->getSettingsRef(); + const std::string query + = R"(CREATE TABLE hdfs_engine_xxxx (name String, value UInt32) ENGINE=HDFS('hdfs://localhost:8020/clickhouse/test2', 'Parquet'))"; + DB::ParserCreateQuery parser; + std::string error_message; + const char * pos = query.data(); + auto ast = DB::tryParseQuery( + parser, + pos, + pos + query.size(), + error_message, + /* hilite = */ false, + "QUERY TEST", + /* allow_multi_statements = */ false, + 0, + settings.max_parser_depth, + settings.max_parser_backtracks, + true); + auto & create = ast->as(); + auto arg = create.storage->children[0]; + const auto * func = arg->as(); + EXPECT_TRUE(func && func->name == "HDFS"); + + DB::StorageHDFSConfiguration config; + StorageObjectStorage::Configuration::initialize(config, arg->children[0]->children, SerializedPlanParser::global_context, false); + + const std::shared_ptr object_storage + = std::dynamic_pointer_cast(config.createObjectStorage(SerializedPlanParser::global_context, false)); + EXPECT_TRUE(object_storage != nullptr); + + RelativePathsWithMetadata files_with_metadata; + object_storage->listObjects("/clickhouse", files_with_metadata, 0); + + /// 1. Create ObjectStorageSink + DB::StorageObjectStorageSink sink{ + object_storage, config.clone(), {}, {{STRING(), "name"}, {UINT(), "value"}}, SerializedPlanParser::global_context, ""}; + + /// 2. Create Chunk + /// 3. comsume + sink.consume(testChunk()); + sink.onFinish(); } -)"; + +namespace DB +{ +SinkToStoragePtr createFilelinkSink( + const StorageMetadataPtr & metadata_snapshot, + const String & table_name_for_log, + const String & path, + CompressionMethod compression_method, + const std::optional & format_settings, + const String & format_name, + const ContextPtr & context, + int flags); } -TEST(SerializedPlanParser, PR65234) +INCBIN(resource_embedded_readcsv_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/read_student_option_schema.csv.json"); +TEST(LocalExecutor, StorageFileSink) { const std::string split - = R"({"items":[{"uriFile":"file:///part-00000-16caa751-9774-470c-bd37-5c84c53373c8-c000.snappy.parquet","length":"84633","parquet":{},"schema":{},"metadataColumns":[{}]}]}")"; + = R"({"items":[{"uriFile":"file:///home/chang/SourceCode/rebase_gluten/backends-velox/src/test/resources/datasource/csv/student_option_schema.csv","length":"56","text":{"fieldDelimiter":",","maxBlockSize":"8192","header":"1"},"schema":{"names":["id","name","language"],"struct":{"types":[{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}},{"string":{"nullability":"NULLABILITY_NULLABLE"}}]}},"metadataColumns":[{}]}]})"; SerializedPlanParser parser(SerializedPlanParser::global_context); - parser.addSplitInfo(splitBinaryFromJson(split)); - parser.parseJson(JsonPlanFor65234()); -} + parser.addSplitInfo(test::pb_util::JsonStringToBinary(split)); + auto local_executor = parser.createExecutor( + {reinterpret_cast(gresource_embedded_readcsv_jsonData), gresource_embedded_readcsv_jsonSize}); + + while (local_executor->hasNext()) + { + const Block & x = *local_executor->nextColumnar(); + EXPECT_EQ(4, x.rows()); + } + + StorageInMemoryMetadata metadata; + metadata.setColumns(ColumnsDescription::fromNamesAndTypes({{"name", STRING()}, {"value", UINT()}})); + StorageMetadataPtr metadata_ptr = std::make_shared(metadata); + + auto sink = createFilelinkSink( + metadata_ptr, + "test_table", + "/tmp/test_table.parquet", + CompressionMethod::None, + 
{}, + "Parquet", + SerializedPlanParser::global_context, + 0); + + sink->consume(testChunk()); + sink->onFinish(); +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json new file mode 100644 index 000000000000..1c37b68b7144 --- /dev/null +++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_65234.json @@ -0,0 +1,273 @@ +{ + "extensions": [{ + "extensionFunction": { + "functionAnchor": 1, + "name": "is_not_null:str" + } + }, { + "extensionFunction": { + "functionAnchor": 2, + "name": "equal:str_str" + } + }, { + "extensionFunction": { + "functionAnchor": 3, + "name": "is_not_null:i64" + } + }, { + "extensionFunction": { + "name": "and:bool_bool" + } + }], + "relations": [{ + "root": { + "input": { + "project": { + "common": { + "emit": { + "outputMapping": [2] + } + }, + "input": { + "filter": { + "common": { + "direct": { + } + }, + "input": { + "read": { + "common": { + "direct": { + } + }, + "baseSchema": { + "names": ["r_regionkey", "r_name"], + "struct": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }, { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }] + }, + "columnTypes": ["NORMAL_COL", "NORMAL_COL"] + }, + "filter": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + } + } + } + }] + } + } + }] + } + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree\u003d0\n" + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "scalarFunction": { + "functionReference": 1, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + "functionReference": 2, + "outputType": { + "bool": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + "field": 1 + } + } + } + } + }, { + "value": { + "literal": { + "string": "EUROPE" + } + } + }] + } + } + }] + } + } + }, { + "value": { + "scalarFunction": { + 
"functionReference": 3, + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [{ + "value": { + "selection": { + "directReference": { + "structField": { + } + } + } + } + }] + } + } + }] + } + } + } + }, + "expressions": [{ + "selection": { + "directReference": { + "structField": { + } + } + } + }] + } + }, + "names": ["r_regionkey#72"], + "outputSchema": { + "types": [{ + "i64": { + "nullability": "NULLABILITY_NULLABLE" + } + }], + "nullability": "NULLABILITY_REQUIRED" + } + } + }] +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json new file mode 100644 index 000000000000..10f0ea3dfdad --- /dev/null +++ b/cpp-ch/local-engine/tests/json/gtest_local_engine_config.json @@ -0,0 +1,269 @@ +{ + "advancedExtensions": { + "enhancement": { + "@type": "type.googleapis.com/substrait.Expression", + "literal": { + "map": { + "keyValues": [ + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.logger.level" + }, + "value": { + "string": "test" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_sort" + }, + "value": { + "string": "5368709120" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.endpoint" + }, + "value": { + "string": "localhost:9000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.velox.IOThreads" + }, + "value": { + "string": "0" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_read_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.query_plan_enable_optimizations" + }, + "value": { + "string": "false" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.worker.id" + }, + "value": { + "string": "1" + } + }, + { + "key": { + "string": "spark.memory.offHeap.enabled" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.iam.role.session.name" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_connect_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.shuffle.codec" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.local_engine.settings.log_processors_profiles" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.gluten.memory.offHeap.size.in.bytes" + }, + "value": { + "string": "10737418240" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.shuffle.codecBackend" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.sql.orc.compression.codec" + }, + "value": { + "string": "snappy" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.max_bytes_before_external_group_by" + }, + "value": { + "string": "5368709120" + } + }, + { + "key": { + "string": "spark.hadoop.input.write.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.secret.key" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.access.key" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": 
"spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.dfs_client_log_severity" + }, + "value": { + "string": "INFO" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.path.style.access" + }, + "value": { + "string": "true" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.timezone" + }, + "value": { + "string": "Asia/Shanghai" + } + }, + { + "key": { + "string": "spark.hadoop.input.read.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.use.instance.credentials" + }, + "value": { + "string": "false" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_settings.output_format_orc_compression_method" + }, + "value": { + "string": "snappy" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.iam.role" + }, + "value": { + "string": "" + } + }, + { + "key": { + "string": "spark.gluten.memory.task.offHeap.size.in.bytes" + }, + "value": { + "string": "10737418240" + } + }, + { + "key": { + "string": "spark.hadoop.input.connect.timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.dfs.client.log.severity" + }, + "value": { + "string": "INFO" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.velox.SplitPreloadPerDriver" + }, + "value": { + "string": "2" + } + }, + { + "key": { + "string": "spark.gluten.sql.columnar.backend.ch.runtime_config.hdfs.input_write_timeout" + }, + "value": { + "string": "180000" + } + }, + { + "key": { + "string": "spark.hadoop.fs.s3a.connection.ssl.enabled" + }, + "value": { + "string": "false" + } + } + ] + } + } + } + } +} \ No newline at end of file diff --git a/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json new file mode 100644 index 000000000000..f9518d39014a --- /dev/null +++ b/cpp-ch/local-engine/tests/json/read_student_option_schema.csv.json @@ -0,0 +1,77 @@ +{ + "relations": [ + { + "root": { + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "id", + "name", + "language" + ], + "struct": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "columnTypes": [ + "NORMAL_COL", + "NORMAL_COL", + "NORMAL_COL" + ] + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree=0\n" + } + } + } + }, + "names": [ + "id#20", + "name#21", + "language#22" + ], + "outputSchema": { + "types": [ + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + { + "string": { + "nullability": "NULLABILITY_NULLABLE" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + } + } + ] +} \ No newline at end of file diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala index 9a37c4a40dd1..3ca5e0313924 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala @@ -430,7 +430,9 @@ trait SparkPlanExecApi { * * @return */ - def genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] + def 
genExtendedColumnarPostRules(): List[SparkSession => Rule[SparkPlan]] = { + SparkShimLoader.getSparkShims.getExtendedColumnarPostRules() ::: List() + } def genInjectPostHocResolutionRules(): List[SparkSession => Rule[LogicalPlan]] diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala index 77d5d55f618d..a6ec7cb21fbf 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitPlanPrinterUtil.scala @@ -24,37 +24,34 @@ import io.substrait.proto.{NamedStruct, Plan} object SubstraitPlanPrinterUtil extends Logging { - /** Transform Substrait Plan to json format. */ - def substraitPlanToJson(substraintPlan: Plan): String = { + private def typeRegistry( + d: com.google.protobuf.Descriptors.Descriptor): com.google.protobuf.TypeRegistry = { val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry + com.google.protobuf.TypeRegistry .newBuilder() - .add(substraintPlan.getDescriptorForType()) + .add(d) .add(defaultRegistry) .build() - JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan) + } + private def MessageToJson(message: com.google.protobuf.Message): String = { + val registry = typeRegistry(message.getDescriptorForType) + JsonFormat.printer.usingTypeRegistry(registry).print(message) } - def substraitNamedStructToJson(substraintPlan: NamedStruct): String = { - val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry - .newBuilder() - .add(substraintPlan.getDescriptorForType()) - .add(defaultRegistry) - .build() - JsonFormat.printer.usingTypeRegistry(registry).print(substraintPlan) + /** Transform Substrait Plan to json format. */ + def substraitPlanToJson(substraitPlan: Plan): String = { + MessageToJson(substraitPlan) + } + + def substraitNamedStructToJson(namedStruct: NamedStruct): String = { + MessageToJson(namedStruct) } /** Transform substrait plan json string to PlanNode */ def jsonToSubstraitPlan(planJson: String): Plan = { try { val builder = Plan.newBuilder() - val defaultRegistry = WrappersProto.getDescriptor.getMessageTypes - val registry = com.google.protobuf.TypeRegistry - .newBuilder() - .add(builder.getDescriptorForType) - .add(defaultRegistry) - .build() + val registry = typeRegistry(builder.getDescriptorForType) JsonFormat.parser().usingTypeRegistry(registry).merge(planJson, builder) builder.build() } catch {