diff --git a/.github/workflows/dev_cron.yml b/.github/workflows/dev_cron.yml index 48ca21510fd9..193549cc077d 100644 --- a/.github/workflows/dev_cron.yml +++ b/.github/workflows/dev_cron.yml @@ -27,15 +27,16 @@ jobs: process: name: Process runs-on: ubuntu-latest + permissions: write-all steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Comment Issues link if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'edited') - uses: actions/github-script@v3 + uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -47,7 +48,7 @@ jobs: github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'edited') - uses: actions/github-script@v3 + uses: actions/github-script@v7 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/dev_cron/issues_link.js b/.github/workflows/dev_cron/issues_link.js index 596bad758532..e47ecb50a55a 100644 --- a/.github/workflows/dev_cron/issues_link.js +++ b/.github/workflows/dev_cron/issues_link.js @@ -35,7 +35,7 @@ async function haveComment(github, context, pullRequestNumber, body) { page: 1 }; while (true) { - const response = await github.issues.listComments(options); + const response = await github.rest.issues.listComments(options); if (response.data.some(comment => comment.body === body)) { return true; } @@ -52,7 +52,7 @@ async function commentISSUESURL(github, context, pullRequestNumber, issuesID) { if (await haveComment(github, context, pullRequestNumber, issuesURL)) { return; } - await github.issues.createComment({ + await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: pullRequestNumber, diff --git a/.github/workflows/dev_cron/title_check.js b/.github/workflows/dev_cron/title_check.js index e553e20b025e..1e6df340f2f2 100644 --- a/.github/workflows/dev_cron/title_check.js +++ b/.github/workflows/dev_cron/title_check.js @@ -25,7 +25,7 @@ function haveISSUESID(title) { } async function commentOpenISSUESIssue(github, context, pullRequestNumber) { - const {data: comments} = await github.issues.listComments({ + const {data: comments} = await github.rest.issues.listComments({ owner: context.repo.owner, repo: context.repo.repo, issue_number: pullRequestNumber, @@ -36,7 +36,7 @@ async function commentOpenISSUESIssue(github, context, pullRequestNumber) { } const commentPath = ".github/workflows/dev_cron/title_check.md"; const comment = fs.readFileSync(commentPath).toString(); - await github.issues.createComment({ + await github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: pullRequestNumber, diff --git a/.github/workflows/docker_image.yml b/.github/workflows/docker_image.yml index 7eecfdecc58d..6a5697e3aae4 100644 --- a/.github/workflows/docker_image.yml +++ b/.github/workflows/docker_image.yml @@ -16,7 +16,9 @@ name: Build and Push Docker Image on: - pull_request: + push: + branches: + - main paths: - '.github/workflows/docker_image.yml' schedule: diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java index f5f75dc1dca6..7b765924fa0d 100644 --- a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java +++ b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CHNativeCacheManager.java @@ -19,9 +19,15 @@ 
import java.util.Set; public class CHNativeCacheManager { - public static void cacheParts(String table, Set columns, boolean async) { - nativeCacheParts(table, String.join(",", columns), async); + public static String cacheParts(String table, Set columns) { + return nativeCacheParts(table, String.join(",", columns)); } - private static native void nativeCacheParts(String table, String columns, boolean async); + private static native String nativeCacheParts(String table, String columns); + + public static CacheResult getCacheStatus(String jobId) { + return nativeGetCacheStatus(jobId); + } + + private static native CacheResult nativeGetCacheStatus(String jobId); } diff --git a/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java new file mode 100644 index 000000000000..0fa69e0d0b1f --- /dev/null +++ b/backends-clickhouse/src/main/java/org/apache/gluten/execution/CacheResult.java @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution; + +public class CacheResult { + public enum Status { + RUNNING(0), + SUCCESS(1), + ERROR(2); + + private final int value; + + Status(int value) { + this.value = value; + } + + public int getValue() { + return value; + } + + public static Status fromInt(int value) { + for (Status myEnum : Status.values()) { + if (myEnum.getValue() == value) { + return myEnum; + } + } + throw new IllegalArgumentException("No enum constant for value: " + value); + } + } + + private final Status status; + private final String message; + + public CacheResult(int status, String message) { + this.status = Status.fromInt(status); + this.message = message; + } + + public Status getStatus() { + return status; + } + + public String getMessage() { + return message; + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala index 4d90ab6533ba..8a3bde235887 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenExecutorEndpoint.scala @@ -64,8 +64,6 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf) hashIds.forEach( resource_id => CHBroadcastBuildSideCache.invalidateBroadcastHashtable(resource_id)) } - case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) => - CHNativeCacheManager.cacheParts(mergeTreeTable, columns, true) case e => logError(s"Received unexpected message. 
$e") @@ -74,12 +72,16 @@ class GlutenExecutorEndpoint(val executorId: String, val conf: SparkConf) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case GlutenMergeTreeCacheLoad(mergeTreeTable, columns) => try { - CHNativeCacheManager.cacheParts(mergeTreeTable, columns, false) - context.reply(CacheLoadResult(true)) + val jobId = CHNativeCacheManager.cacheParts(mergeTreeTable, columns) + context.reply(CacheJobInfo(status = true, jobId)) } catch { case _: Exception => - context.reply(CacheLoadResult(false, s"executor: $executorId cache data failed.")) + context.reply( + CacheJobInfo(status = false, "", s"executor: $executorId cache data failed.")) } + case GlutenMergeTreeCacheLoadStatus(jobId) => + val status = CHNativeCacheManager.getCacheStatus(jobId) + context.reply(status) case e => logError(s"Received unexpected message. $e") } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala index d675d705f10a..800b15b9949b 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/rpc/GlutenRpcMessages.scala @@ -35,8 +35,12 @@ object GlutenRpcMessages { case class GlutenCleanExecutionResource(executionId: String, broadcastHashIds: util.Set[String]) extends GlutenRpcMessage + // for mergetree cache case class GlutenMergeTreeCacheLoad(mergeTreeTable: String, columns: util.Set[String]) extends GlutenRpcMessage - case class CacheLoadResult(success: Boolean, reason: String = "") extends GlutenRpcMessage + case class GlutenMergeTreeCacheLoadStatus(jobId: String) + + case class CacheJobInfo(status: Boolean, jobId: String, reason: String = "") + extends GlutenRpcMessage } diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala index 1e6b024063b6..f32d22d5eac0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/commands/GlutenCHCacheDataCommand.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.commands import org.apache.gluten.exception.GlutenException +import org.apache.gluten.execution.CacheResult +import org.apache.gluten.execution.CacheResult.Status import org.apache.gluten.expression.ConverterUtils import org.apache.gluten.substrait.rel.ExtensionTableBuilder import org.apache.spark.affinity.CHAffinity import org.apache.spark.rpc.GlutenDriverEndpoint -import org.apache.spark.rpc.GlutenRpcMessages.{CacheLoadResult, GlutenMergeTreeCacheLoad} +import org.apache.spark.rpc.GlutenRpcMessages.{CacheJobInfo, GlutenMergeTreeCacheLoad, GlutenMergeTreeCacheLoadStatus} import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, GreaterThanOrEqual, IsNotNull, Literal} import org.apache.spark.sql.delta._ import org.apache.spark.sql.execution.command.LeafRunnableCommand -import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.toExecutorId +import org.apache.spark.sql.execution.commands.GlutenCHCacheDataCommand.{checkExecutorId, collectJobTriggerResult, toExecutorId, waitAllJobFinish, waitRpcResults} import 
org.apache.spark.sql.execution.datasources.v2.clickhouse.metadata.AddMergeTreeParts import org.apache.spark.sql.types.{BooleanType, StringType} import org.apache.spark.util.ThreadUtils @@ -106,7 +108,8 @@ case class GlutenCHCacheDataCommand( } val selectedAddFiles = if (tsfilter.isDefined) { - val allParts = DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false) + val allParts = + DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false) allParts.files.filter(_.modificationTime >= tsfilter.get.toLong).toSeq } else if (partitionColumn.isDefined && partitionValue.isDefined) { val partitionColumns = snapshot.metadata.partitionSchema.fieldNames @@ -126,10 +129,12 @@ case class GlutenCHCacheDataCommand( snapshot, Seq(partitionColumnAttr), Seq(isNotNullExpr, greaterThanOrEqual), - false) + keepNumRecords = false) .files } else { - DeltaAdapter.snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, false).files + DeltaAdapter + .snapshotFilesForScan(snapshot, Seq.empty, Seq.empty, keepNumRecords = false) + .files } val executorIdsToAddFiles = @@ -151,9 +156,7 @@ case class GlutenCHCacheDataCommand( if (locations.isEmpty) { // non soft affinity - executorIdsToAddFiles - .get(GlutenCHCacheDataCommand.ALL_EXECUTORS) - .get + executorIdsToAddFiles(GlutenCHCacheDataCommand.ALL_EXECUTORS) .append(mergeTreePart) } else { locations.foreach( @@ -161,7 +164,7 @@ case class GlutenCHCacheDataCommand( if (!executorIdsToAddFiles.contains(executor)) { executorIdsToAddFiles.put(executor, new ArrayBuffer[AddMergeTreeParts]()) } - executorIdsToAddFiles.get(executor).get.append(mergeTreePart) + executorIdsToAddFiles(executor).append(mergeTreePart) }) } }) @@ -201,87 +204,112 @@ case class GlutenCHCacheDataCommand( executorIdsToParts.put(executorId, extensionTableNode.getExtensionTableStr) } }) - - // send rpc call + val futureList = ArrayBuffer[(String, Future[CacheJobInfo])]() if (executorIdsToParts.contains(GlutenCHCacheDataCommand.ALL_EXECUTORS)) { // send all parts to all executors - val tableMessage = executorIdsToParts.get(GlutenCHCacheDataCommand.ALL_EXECUTORS).get - if (asynExecute) { - GlutenDriverEndpoint.executorDataMap.forEach( - (executorId, executor) => { - executor.executorEndpointRef.send( - GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava)) - }) - Seq(Row(true, "")) - } else { - val futureList = ArrayBuffer[Future[CacheLoadResult]]() - val resultList = ArrayBuffer[CacheLoadResult]() - GlutenDriverEndpoint.executorDataMap.forEach( - (executorId, executor) => { - futureList.append( - executor.executorEndpointRef.ask[CacheLoadResult]( + val tableMessage = executorIdsToParts(GlutenCHCacheDataCommand.ALL_EXECUTORS) + GlutenDriverEndpoint.executorDataMap.forEach( + (executorId, executor) => { + futureList.append( + ( + executorId, + executor.executorEndpointRef.ask[CacheJobInfo]( GlutenMergeTreeCacheLoad(tableMessage, selectedColumns.toSet.asJava) - )) - }) - futureList.foreach( - f => { - resultList.append(ThreadUtils.awaitResult(f, Duration.Inf)) - }) - if (resultList.exists(!_.success)) { - Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";"))) - } else { - Seq(Row(true, "")) - } - } + ))) + }) } else { - if (asynExecute) { - executorIdsToParts.foreach( - value => { - val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) - if (executorData != null) { - executorData.executorEndpointRef.send( - GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava)) - } else { - throw new 
GlutenException( - s"executor ${value._1} not found," + - s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") - } - }) - Seq(Row(true, "")) - } else { - val futureList = ArrayBuffer[Future[CacheLoadResult]]() - val resultList = ArrayBuffer[CacheLoadResult]() - executorIdsToParts.foreach( - value => { - val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) - if (executorData != null) { - futureList.append( - executorData.executorEndpointRef.ask[CacheLoadResult]( - GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava) - )) - } else { - throw new GlutenException( - s"executor ${value._1} not found," + - s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") - } - }) - futureList.foreach( - f => { - resultList.append(ThreadUtils.awaitResult(f, Duration.Inf)) - }) - if (resultList.exists(!_.success)) { - Seq(Row(false, resultList.filter(!_.success).map(_.reason).mkString(";"))) - } else { - Seq(Row(true, "")) - } - } + executorIdsToParts.foreach( + value => { + checkExecutorId(value._1) + val executorData = GlutenDriverEndpoint.executorDataMap.get(toExecutorId(value._1)) + futureList.append( + ( + value._1, + executorData.executorEndpointRef.ask[CacheJobInfo]( + GlutenMergeTreeCacheLoad(value._2, selectedColumns.toSet.asJava) + ))) + }) + } + val resultList = waitRpcResults(futureList) + if (asynExecute) { + val res = collectJobTriggerResult(resultList) + Seq(Row(res._1, res._2.mkString(";"))) + } else { + val res = waitAllJobFinish(resultList) + Seq(Row(res._1, res._2)) } } + } object GlutenCHCacheDataCommand { - val ALL_EXECUTORS = "allExecutors" + private val ALL_EXECUTORS = "allExecutors" private def toExecutorId(executorId: String): String = executorId.split("_").last + + def waitAllJobFinish(jobs: ArrayBuffer[(String, CacheJobInfo)]): (Boolean, String) = { + val res = collectJobTriggerResult(jobs) + var status = res._1 + val messages = res._2 + jobs.foreach( + job => { + if (status) { + var complete = false + while (!complete) { + Thread.sleep(5000) + val future_result = GlutenDriverEndpoint.executorDataMap + .get(toExecutorId(job._1)) + .executorEndpointRef + .ask[CacheResult](GlutenMergeTreeCacheLoadStatus(job._2.jobId)) + val result = ThreadUtils.awaitResult(future_result, Duration.Inf) + result.getStatus match { + case Status.ERROR => + status = false + messages.append( + s"executor : {}, failed with message: {};", + job._1, + result.getMessage) + complete = true + case Status.SUCCESS => + complete = true + case _ => + // still running + } + } + } + }) + (status, messages.mkString(";")) + } + + private def collectJobTriggerResult(jobs: ArrayBuffer[(String, CacheJobInfo)]) = { + var status = true + val messages = ArrayBuffer[String]() + jobs.foreach( + job => { + if (!job._2.status) { + messages.append(job._2.reason) + status = false + } + }) + (status, messages) + } + + private def waitRpcResults = (futureList: ArrayBuffer[(String, Future[CacheJobInfo])]) => { + val resultList = ArrayBuffer[(String, CacheJobInfo)]() + futureList.foreach( + f => { + resultList.append((f._1, ThreadUtils.awaitResult(f._2, Duration.Inf))) + }) + resultList + } + + private def checkExecutorId(executorId: String): Unit = { + if (!GlutenDriverEndpoint.executorDataMap.containsKey(toExecutorId(executorId))) { + throw new GlutenException( + s"executor $executorId not found," + + s" all executors are ${GlutenDriverEndpoint.executorDataMap.toString}") + } + } + } diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnS3Suite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnS3Suite.scala index 6a473cc54f7e..87e95cbe9dda 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnS3Suite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseMergeTreeWriteOnS3Suite.scala @@ -188,20 +188,33 @@ class GlutenClickHouseMergeTreeWriteOnS3Suite var metadataGlutenExist: Boolean = false var metadataBinExist: Boolean = false var dataBinExist: Boolean = false + var hasCommits = false client .listObjects(args) .forEach( obj => { objectCount += 1 - if (obj.get().objectName().contains("metadata.gluten")) { + val objectName = obj.get().objectName() + if (objectName.contains("metadata.gluten")) { metadataGlutenExist = true - } else if (obj.get().objectName().contains("meta.bin")) { + } else if (objectName.contains("meta.bin")) { metadataBinExist = true - } else if (obj.get().objectName().contains("data.bin")) { + } else if (objectName.contains("data.bin")) { dataBinExist = true + } else if (objectName.contains("_commits")) { + // Spark 35 has _commits directory + // table/_delta_log/_commits/ + hasCommits = true } }) - assertResult(5)(objectCount) + + if (isSparkVersionGE("3.5")) { + assertResult(6)(objectCount) + assert(hasCommits) + } else { + assertResult(5)(objectCount) + } + assert(metadataGlutenExist) assert(metadataBinExist) assert(dataBinExist) diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala index 4972861152fd..f914eaa1860a 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseWholeStageTransformerSuite.scala @@ -178,11 +178,13 @@ class GlutenClickHouseWholeStageTransformerSuite extends WholeStageTransformerSu super.beforeAll() } - protected val rootPath: String = this.getClass.getResource("/").getPath - protected val basePath: String = rootPath + "tests-working-home" - protected val warehouse: String = basePath + "/spark-warehouse" - protected val metaStorePathAbsolute: String = basePath + "/meta" - protected val hiveMetaStoreDB: String = metaStorePathAbsolute + "/metastore_db" + final protected val rootPath: String = this.getClass.getResource("/").getPath + final protected val basePath: String = rootPath + "tests-working-home" + final protected val warehouse: String = basePath + "/spark-warehouse" + final protected val metaStorePathAbsolute: String = basePath + "/meta" + + protected val hiveMetaStoreDB: String = + s"$metaStorePathAbsolute/${getClass.getSimpleName}/metastore_db" final override protected val resourcePath: String = "" // ch not need this override protected val fileFormat: String = "parquet" diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala index 28ff5874fabd..383681733026 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala +++ 
b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseCountDistinctSuite.scala @@ -16,7 +16,8 @@ */ package org.apache.gluten.execution -import org.apache.gluten.execution.AllDataTypesWithComplexType.genTestData +import org.apache.gluten.test.AllDataTypesWithComplexType +import org.apache.gluten.test.AllDataTypesWithComplexType.genTestData import org.apache.spark.SparkConf class GlutenClickhouseCountDistinctSuite extends GlutenClickHouseWholeStageTransformerSuite { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala index 1d4d1b6f8afb..ac18f256e807 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickhouseFunctionSuite.scala @@ -20,12 +20,6 @@ import org.apache.gluten.GlutenConfig import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf -import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.delta.DeltaLog - -import org.apache.commons.io.FileUtils - -import java.io.File class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { override protected val needCopyParquetToTablePath = true @@ -39,9 +33,6 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { createNotNullTPCHTablesInParquet(tablesPath) } - private var _hiveSpark: SparkSession = _ - override protected def spark: SparkSession = _hiveSpark - override protected def sparkConf: SparkConf = { new SparkConf() .set("spark.plugins", "org.apache.gluten.GlutenPlugin") @@ -69,70 +60,21 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { .setMaster("local[1]") } - override protected def initializeSession(): Unit = { - if (_hiveSpark == null) { - val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db" - _hiveSpark = SparkSession - .builder() - .config(sparkConf) - .enableHiveSupport() - .config( - "javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$hiveMetaStoreDB;create=true") - .getOrCreate() - } - } - - override def beforeAll(): Unit = { - // prepare working paths - val basePathDir = new File(basePath) - if (basePathDir.exists()) { - FileUtils.forceDelete(basePathDir) - } - FileUtils.forceMkdir(basePathDir) - FileUtils.forceMkdir(new File(warehouse)) - FileUtils.forceMkdir(new File(metaStorePathAbsolute)) - FileUtils.copyDirectory(new File(rootPath + resourcePath), new File(tablesPath)) - super.beforeAll() - } - - override protected def afterAll(): Unit = { - DeltaLog.clearCache() - - try { - super.afterAll() - } finally { - try { - if (_hiveSpark != null) { - try { - _hiveSpark.sessionState.catalog.reset() - } finally { - _hiveSpark.stop() - _hiveSpark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } - } - test("test uuid - write and read") { withSQLConf( ("spark.gluten.sql.native.writer.enabled", "true"), (GlutenConfig.GLUTEN_ENABLED.key, "true")) { + withTable("uuid_test") { + spark.sql("create table if not exists uuid_test (id string) using parquet") - spark.sql("drop table if exists uuid_test") - spark.sql("create table if not exists uuid_test (id string) stored as parquet") - - val df = spark.sql("select regexp_replace(uuid(), '-', '') as id from range(1)") - df.cache() - df.write.insertInto("uuid_test") + val df = 
spark.sql("select regexp_replace(uuid(), '-', '') as id from range(1)") + df.cache() + df.write.insertInto("uuid_test") - val df2 = spark.table("uuid_test") - val diffCount = df.exceptAll(df2).count() - assert(diffCount == 0) + val df2 = spark.table("uuid_test") + val diffCount = df.exceptAll(df2).count() + assert(diffCount == 0) + } } } @@ -181,49 +123,51 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite { } test("GLUTEN-5981 null value from get_json_object") { - spark.sql("create table json_t1 (a string) using parquet") - spark.sql("insert into json_t1 values ('{\"a\":null}')") - runQueryAndCompare( - """ - |SELECT get_json_object(a, '$.a') is null from json_t1 - |""".stripMargin - )(df => checkFallbackOperators(df, 0)) - spark.sql("drop table json_t1") + withTable("json_t1") { + spark.sql("create table json_t1 (a string) using parquet") + spark.sql("insert into json_t1 values ('{\"a\":null}')") + runQueryAndCompare( + """ + |SELECT get_json_object(a, '$.a') is null from json_t1 + |""".stripMargin + )(df => checkFallbackOperators(df, 0)) + } } test("Fix arrayDistinct(Array(Nullable(Decimal))) core dump") { - val create_sql = - """ - |create table if not exists test( - | dec array - |) using parquet - |""".stripMargin - val fill_sql = - """ - |insert into test values(array(1, 2, null)), (array(null, 2,3, 5)) - |""".stripMargin - val query_sql = - """ - |select array_distinct(dec) from test; - |""".stripMargin - spark.sql(create_sql) - spark.sql(fill_sql) - compareResultsAgainstVanillaSpark(query_sql, true, { _ => }) - spark.sql("drop table test") + withTable("json_t1") { + val create_sql = + """ + |create table if not exists test( + | dec array + |) using parquet + |""".stripMargin + val fill_sql = + """ + |insert into test values(array(1, 2, null)), (array(null, 2,3, 5)) + |""".stripMargin + val query_sql = + """ + |select array_distinct(dec) from test; + |""".stripMargin + spark.sql(create_sql) + spark.sql(fill_sql) + compareResultsAgainstVanillaSpark(query_sql, true, { _ => }) + } } test("intersect all") { - spark.sql("create table t1 (a int, b string) using parquet") - spark.sql("insert into t1 values (1, '1'),(2, '2'),(3, '3'),(4, '4'),(5, '5'),(6, '6')") - spark.sql("create table t2 (a int, b string) using parquet") - spark.sql("insert into t2 values (4, '4'),(5, '5'),(6, '6'),(7, '7'),(8, '8'),(9, '9')") - runQueryAndCompare( - """ - |SELECT a,b FROM t1 INTERSECT ALL SELECT a,b FROM t2 - |""".stripMargin - )(df => checkFallbackOperators(df, 0)) - spark.sql("drop table t1") - spark.sql("drop table t2") + withTable("t1", "t2") { + spark.sql("create table t1 (a int, b string) using parquet") + spark.sql("insert into t1 values (1, '1'),(2, '2'),(3, '3'),(4, '4'),(5, '5'),(6, '6')") + spark.sql("create table t2 (a int, b string) using parquet") + spark.sql("insert into t2 values (4, '4'),(5, '5'),(6, '6'),(7, '7'),(8, '8'),(9, '9')") + runQueryAndCompare( + """ + |SELECT a,b FROM t1 INTERSECT ALL SELECT a,b FROM t2 + |""".stripMargin + )(df => checkFallbackOperators(df, 0)) + } } test("array decimal32 CH column to row") { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala similarity index 94% rename from backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala rename to 
backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala index 83bc4e76b1bd..cc9155613343 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseHiveTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseHiveTableSuite.scala @@ -14,13 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.gluten.execution +package org.apache.gluten.execution.hive import org.apache.gluten.GlutenConfig +import org.apache.gluten.execution.{GlutenClickHouseWholeStageTransformerSuite, ProjectExecTransformer, TransformSupport} +import org.apache.gluten.test.AllDataTypesWithComplexType import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf -import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession} +import org.apache.spark.sql.{DataFrame, SaveMode} import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.hive.HiveTableScanExecTransformer @@ -29,64 +31,14 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.hadoop.fs.Path import java.io.{File, PrintWriter} -import java.sql.{Date, Timestamp} import scala.reflect.ClassTag -case class AllDataTypesWithComplexType( - string_field: String = null, - int_field: java.lang.Integer = null, - long_field: java.lang.Long = null, - float_field: java.lang.Float = null, - double_field: java.lang.Double = null, - short_field: java.lang.Short = null, - byte_field: java.lang.Byte = null, - boolean_field: java.lang.Boolean = null, - decimal_field: java.math.BigDecimal = null, - date_field: java.sql.Date = null, - timestamp_field: java.sql.Timestamp = null, - array: Seq[Int] = null, - arrayContainsNull: Seq[Option[Int]] = null, - map: Map[Int, Long] = null, - mapValueContainsNull: Map[Int, Option[Long]] = null -) - -object AllDataTypesWithComplexType { - def genTestData(): Seq[AllDataTypesWithComplexType] = { - (0 to 199).map { - i => - if (i % 100 == 1) { - AllDataTypesWithComplexType() - } else { - AllDataTypesWithComplexType( - s"$i", - i, - i.toLong, - i.toFloat, - i.toDouble, - i.toShort, - i.toByte, - i % 2 == 0, - new java.math.BigDecimal(i + ".56"), - Date.valueOf(new Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)), - Timestamp.valueOf( - new Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)), - Seq.apply(i + 1, i + 2, i + 3), - Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)), - Map.apply((i + 1, i + 2), (i + 3, i + 4)), - Map.empty - ) - } - } - } -} - class GlutenClickHouseHiveTableSuite extends GlutenClickHouseWholeStageTransformerSuite + with ReCreateHiveSession with AdaptiveSparkPlanHelper { - private var _hiveSpark: SparkSession = _ - override protected def sparkConf: SparkConf = { new SparkConf() .set("spark.plugins", "org.apache.gluten.GlutenPlugin") @@ -119,22 +71,6 @@ class GlutenClickHouseHiveTableSuite .setMaster("local[*]") } - override protected def spark: SparkSession = _hiveSpark - - override protected def initializeSession(): Unit = { - if (_hiveSpark == null) { - val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db" - _hiveSpark = SparkSession - .builder() - .config(sparkConf) - .enableHiveSupport() - .config( - "javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$hiveMetaStoreDB;create=true") - .getOrCreate() - } - } - private val 
txt_table_name = "hive_txt_test" private val txt_user_define_input = "hive_txt_user_define_input" private val json_table_name = "hive_json_test" @@ -235,24 +171,7 @@ class GlutenClickHouseHiveTableSuite override protected def afterAll(): Unit = { DeltaLog.clearCache() - - try { - super.afterAll() - } finally { - try { - if (_hiveSpark != null) { - try { - _hiveSpark.sessionState.catalog.reset() - } finally { - _hiveSpark.stop() - _hiveSpark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } + super.afterAll() } test("test hive text table") { @@ -957,7 +876,7 @@ class GlutenClickHouseHiveTableSuite val select_sql_4 = "select id, get_json_object(data, '$.v111') from test_tbl_3337" val select_sql_5 = "select id, get_json_object(data, 'v112') from test_tbl_3337" val select_sql_6 = - "select id, get_json_object(data, '$.id') from test_tbl_3337 where id = 123"; + "select id, get_json_object(data, '$.id') from test_tbl_3337 where id = 123" compareResultsAgainstVanillaSpark(select_sql_1, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_2, compareResult = true, _ => {}) compareResultsAgainstVanillaSpark(select_sql_3, compareResult = true, _ => {}) @@ -1311,7 +1230,7 @@ class GlutenClickHouseHiveTableSuite .format(dataPath) val select_sql = "select * from test_tbl_6506" spark.sql(create_table_sql) - compareResultsAgainstVanillaSpark(select_sql, true, _ => {}) + compareResultsAgainstVanillaSpark(select_sql, compareResult = true, _ => {}) spark.sql("drop table test_tbl_6506") } diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala similarity index 96% rename from backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala rename to backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala index 652b15fc2da0..9e3fa00787de 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala @@ -14,33 +14,27 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.gluten.execution +package org.apache.gluten.execution.hive import org.apache.gluten.GlutenConfig -import org.apache.gluten.execution.AllDataTypesWithComplexType.genTestData +import org.apache.gluten.execution.GlutenClickHouseWholeStageTransformerSuite +import org.apache.gluten.test.AllDataTypesWithComplexType.genTestData import org.apache.gluten.utils.UTSystemParameters import org.apache.spark.SparkConf import org.apache.spark.gluten.NativeWriteChecker -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.delta.DeltaLog import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.test.SharedSparkSession -import org.apache.spark.sql.types.{DecimalType, LongType, StringType, StructField, StructType} - -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.types._ import scala.reflect.runtime.universe.TypeTag class GlutenClickHouseNativeWriteTableSuite extends GlutenClickHouseWholeStageTransformerSuite with AdaptiveSparkPlanHelper - with SharedSparkSession - with BeforeAndAfterAll + with ReCreateHiveSession with NativeWriteChecker { - private var _hiveSpark: SparkSession = _ - override protected def sparkConf: SparkConf = { var sessionTimeZone = "GMT" if (isSparkVersionGE("3.5")) { @@ -80,45 +74,12 @@ class GlutenClickHouseNativeWriteTableSuite basePath + "/中文/spark-warehouse" } - override protected def spark: SparkSession = _hiveSpark - - override protected def initializeSession(): Unit = { - if (_hiveSpark == null) { - val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db" - _hiveSpark = SparkSession - .builder() - .config(sparkConf) - .enableHiveSupport() - .config( - "javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$hiveMetaStoreDB;create=true") - .getOrCreate() - } - } - private val table_name_template = "hive_%s_test" private val table_name_vanilla_template = "hive_%s_test_written_by_vanilla" override protected def afterAll(): Unit = { DeltaLog.clearCache() - - try { - super.afterAll() - } finally { - try { - if (_hiveSpark != null) { - try { - _hiveSpark.sessionState.catalog.reset() - } finally { - _hiveSpark.stop() - _hiveSpark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } + super.afterAll() } def getColumnName(s: String): String = { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTableAfterRestart.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseTableAfterRestart.scala similarity index 87% rename from backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTableAfterRestart.scala rename to backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseTableAfterRestart.scala index f9e831cb4aa7..d359428d03ca 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/GlutenClickHouseTableAfterRestart.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseTableAfterRestart.scala @@ -14,12 +14,14 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package org.apache.gluten.execution +package org.apache.gluten.execution.hive + +import org.apache.gluten.execution.GlutenClickHouseTPCHAbstractSuite import org.apache.spark.SparkConf import org.apache.spark.sql.SparkSession import org.apache.spark.sql.SparkSession.{getActiveSession, getDefaultSession} -import org.apache.spark.sql.delta.{ClickhouseSnapshot, DeltaLog} +import org.apache.spark.sql.delta.ClickhouseSnapshot import org.apache.spark.sql.delta.catalog.ClickHouseTableV2 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -33,7 +35,8 @@ import java.io.File // This suite is to make sure clickhouse commands works well even after spark restart class GlutenClickHouseTableAfterRestart extends GlutenClickHouseTPCHAbstractSuite - with AdaptiveSparkPlanHelper { + with AdaptiveSparkPlanHelper + with ReCreateHiveSession { override protected val needCopyParquetToTablePath = true @@ -64,56 +67,18 @@ class GlutenClickHouseTableAfterRestart .set( "spark.gluten.sql.columnar.backend.ch.runtime_settings.input_format_parquet_max_block_size", "8192") + .setMaster("local[2]") } override protected def createTPCHNotNullTables(): Unit = { createNotNullTPCHTablesInParquet(tablesPath) } - private var _hiveSpark: SparkSession = _ - override protected def spark: SparkSession = _hiveSpark - - override protected def initializeSession(): Unit = { - if (_hiveSpark == null) { - val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db_" + current_db_num - current_db_num += 1 - - _hiveSpark = SparkSession - .builder() - .config(sparkConf) - .enableHiveSupport() - .config( - "javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$hiveMetaStoreDB;create=true") - .master("local[2]") - .getOrCreate() - } - } - - override protected def afterAll(): Unit = { - DeltaLog.clearCache() - - try { - super.afterAll() - } finally { - try { - if (_hiveSpark != null) { - try { - _hiveSpark.sessionState.catalog.reset() - } finally { - _hiveSpark.stop() - _hiveSpark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } - } - var current_db_num: Int = 0 + override protected val hiveMetaStoreDB: String = + metaStorePathAbsolute + "/metastore_db_" + current_db_num + test("test mergetree after restart") { spark.sql(s""" |DROP TABLE IF EXISTS lineitem_mergetree; @@ -347,22 +312,22 @@ class GlutenClickHouseTableAfterRestart SparkSession.clearDefaultSession() } - val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db_" + val metaStoreDB = metaStorePathAbsolute + "/metastore_db_" // use metastore_db2 to avoid issue: "Another instance of Derby may have already booted the database" - val destDir = new File(hiveMetaStoreDB + current_db_num) - destDir.mkdirs() - FileUtils.copyDirectory(new File(hiveMetaStoreDB + (current_db_num - 1)), destDir) - _hiveSpark = null - _hiveSpark = SparkSession - .builder() - .config(sparkConf) - .enableHiveSupport() - .config( - "javax.jdo.option.ConnectionURL", - s"jdbc:derby:;databaseName=$hiveMetaStoreDB$current_db_num") - .master("local[2]") - .getOrCreate() current_db_num += 1 + val destDir = new File(metaStoreDB + current_db_num) + destDir.mkdirs() + FileUtils.copyDirectory(new File(metaStoreDB + (current_db_num - 1)), destDir) + updateHiveSession( + SparkSession + .builder() + .config(sparkConf) + .enableHiveSupport() + .config( + "javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metaStoreDB$current_db_num") + .getOrCreate() + ) } } // scalastyle:off line.size.limit diff --git 
a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/ReCreateHiveSession.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/ReCreateHiveSession.scala new file mode 100644 index 000000000000..c251e46364f5 --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/ReCreateHiveSession.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution.hive + +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.test.SharedSparkSession + +import org.scalatest.BeforeAndAfterAll + +trait ReCreateHiveSession extends SharedSparkSession with BeforeAndAfterAll { + + protected val hiveMetaStoreDB: String + + private var _hiveSpark: SparkSession = _ + + override protected def spark: SparkSession = _hiveSpark + + override protected def initializeSession(): Unit = { + if (_hiveSpark == null) { + _hiveSpark = SparkSession + .builder() + .config(sparkConf) + .enableHiveSupport() + .config( + "javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$hiveMetaStoreDB;create=true") + .getOrCreate() + } + } + + override protected def afterAll(): Unit = { + try { + super.afterAll() + } finally { + try { + if (_hiveSpark != null) { + try { + _hiveSpark.sessionState.catalog.reset() + } finally { + _hiveSpark.stop() + _hiveSpark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } + } + + protected def updateHiveSession(newSession: SparkSession): Unit = { + _hiveSpark = null + _hiveSpark = newSession + } +} diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/parquet/GlutenParquetFilterSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/parquet/GlutenParquetFilterSuite.scala index 1e6509c00884..0a8d1729c810 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/parquet/GlutenParquetFilterSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/parquet/GlutenParquetFilterSuite.scala @@ -35,12 +35,6 @@ class GlutenParquetFilterSuite with GlutenTPCHBase with Logging { - override protected val rootPath = this.getClass.getResource("/").getPath - override protected val basePath = rootPath + "tests-working-home" - override protected val warehouse = basePath + "/spark-warehouse" - override protected val metaStorePathAbsolute = basePath + "/meta" - override protected val hiveMetaStoreDB = metaStorePathAbsolute + "/metastore_db" - private val tpchQueriesResourceFolder: String = rootPath + "../../../../gluten-core/src/test/resources/tpch-queries" diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/test/AllDataTypesWithComplexType.scala 
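The new ReCreateHiveSession trait above centralizes the Hive-enabled SparkSession lifecycle that several suites previously duplicated: a suite only supplies hiveMetaStoreDB (or inherits the per-suite default now defined in GlutenClickHouseWholeStageTransformerSuite) and gets session creation against a Derby metastore, catalog reset, and teardown for free. Below is a minimal Scala sketch of how a suite might mix it in; the suite name and the test body are illustrative assumptions, not code from this patch.

package org.apache.gluten.execution.hive

// Hypothetical example suite mixing in ReCreateHiveSession so that a
// Hive-backed SparkSession is created against a per-suite Derby metastore.
class ExampleHiveBackedSuite
  extends GlutenClickHouseWholeStageTransformerSuite
  with ReCreateHiveSession {

  // The base suite already derives a per-suite metastore_db path from the
  // class name; overriding is only needed when a custom location is wanted,
  // as GlutenClickHouseTableAfterRestart does for its restart scenario.
  override protected val hiveMetaStoreDB: String =
    s"$metaStorePathAbsolute/${getClass.getSimpleName}/metastore_db"

  test("hive-enabled session is available") {
    // spark here is the _hiveSpark instance managed by ReCreateHiveSession.
    assert(spark.sessionState.catalog != null)
  }
}
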
b/backends-clickhouse/src/test/scala/org/apache/gluten/test/AllDataTypesWithComplexType.scala new file mode 100644 index 000000000000..19abcbea433a --- /dev/null +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/test/AllDataTypesWithComplexType.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.test + +import java.sql.{Date, Timestamp} + +case class AllDataTypesWithComplexType( + string_field: String = null, + int_field: java.lang.Integer = null, + long_field: java.lang.Long = null, + float_field: java.lang.Float = null, + double_field: java.lang.Double = null, + short_field: java.lang.Short = null, + byte_field: java.lang.Byte = null, + boolean_field: java.lang.Boolean = null, + decimal_field: java.math.BigDecimal = null, + date_field: java.sql.Date = null, + timestamp_field: java.sql.Timestamp = null, + array: Seq[Int] = null, + arrayContainsNull: Seq[Option[Int]] = null, + map: Map[Int, Long] = null, + mapValueContainsNull: Map[Int, Option[Long]] = null +) + +object AllDataTypesWithComplexType { + def genTestData(): Seq[AllDataTypesWithComplexType] = { + (0 to 199).map { + i => + if (i % 100 == 1) { + AllDataTypesWithComplexType() + } else { + AllDataTypesWithComplexType( + s"$i", + i, + i.toLong, + i.toFloat, + i.toDouble, + i.toShort, + i.toByte, + i % 2 == 0, + new java.math.BigDecimal(i + ".56"), + Date.valueOf(new Date(System.currentTimeMillis()).toLocalDate.plusDays(i % 10)), + Timestamp.valueOf( + new Timestamp(System.currentTimeMillis()).toLocalDateTime.plusDays(i % 10)), + Seq.apply(i + 1, i + 2, i + 3), + Seq.apply(Option.apply(i + 1), Option.empty, Option.apply(i + 3)), + Map.apply((i + 1, i + 2), (i + 3, i + 4)), + Map.empty + ) + } + } + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala index 86925fd1d6a8..2cfc4e9a9099 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala @@ -18,7 +18,6 @@ package org.apache.gluten.backendsapi.velox import org.apache.gluten.GlutenConfig import org.apache.gluten.backendsapi.ListenerApi -import org.apache.gluten.exception.GlutenException import org.apache.gluten.execution.datasource.{GlutenOrcWriterInjects, GlutenParquetWriterInjects, GlutenRowSplitter} import org.apache.gluten.expression.UDFMappings import org.apache.gluten.init.NativeBackendInitializer @@ -27,138 +26,76 @@ import org.apache.gluten.vectorized.{JniLibLoader, JniWorkspace} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.plugin.PluginContext +import 
org.apache.spark.internal.Logging import org.apache.spark.sql.execution.datasources.velox.{VeloxOrcWriterInjects, VeloxParquetWriterInjects, VeloxRowSplitter} import org.apache.spark.sql.expression.UDFResolver import org.apache.spark.sql.internal.{GlutenConfigUtil, StaticSQLConf} -import org.apache.spark.util.SparkDirectoryUtil +import org.apache.spark.util.{SparkDirectoryUtil, SparkResourceUtil} import org.apache.commons.lang3.StringUtils -import scala.sys.process._ +import java.util.concurrent.atomic.AtomicBoolean -class VeloxListenerApi extends ListenerApi { - private val ARROW_VERSION = "1500" +class VeloxListenerApi extends ListenerApi with Logging { + import VeloxListenerApi._ override def onDriverStart(sc: SparkContext, pc: PluginContext): Unit = { val conf = pc.conf() - // sql table cache serializer + + // Sql table cache serializer. if (conf.getBoolean(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, defaultValue = false)) { conf.set( StaticSQLConf.SPARK_CACHE_SERIALIZER.key, "org.apache.spark.sql.execution.ColumnarCachedBatchSerializer") } - initialize(conf, isDriver = true) + + // Static initializers for driver. + if (!driverInitialized.compareAndSet(false, true)) { + // Make sure we call the static initializers only once. + logInfo( + "Skip rerunning static initializers since they are only supposed to run once." + + " You see this message probably because you are creating a new SparkSession.") + return + } + + SparkDirectoryUtil.init(conf) + UDFResolver.resolveUdfConf(conf, isDriver = true) + initialize(conf) } override def onDriverShutdown(): Unit = shutdown() override def onExecutorStart(pc: PluginContext): Unit = { - initialize(pc.conf(), isDriver = false) - } - - override def onExecutorShutdown(): Unit = shutdown() + val conf = pc.conf() - private def getLibraryLoaderForOS( - systemName: String, - systemVersion: String, - system: String): SharedLibraryLoader = { - if (systemName.contains("Ubuntu") && systemVersion.startsWith("20.04")) { - new SharedLibraryLoaderUbuntu2004 - } else if (systemName.contains("Ubuntu") && systemVersion.startsWith("22.04")) { - new SharedLibraryLoaderUbuntu2204 - } else if (systemName.contains("CentOS") && systemVersion.startsWith("9")) { - new SharedLibraryLoaderCentos9 - } else if (systemName.contains("CentOS") && systemVersion.startsWith("8")) { - new SharedLibraryLoaderCentos8 - } else if (systemName.contains("CentOS") && systemVersion.startsWith("7")) { - new SharedLibraryLoaderCentos7 - } else if (systemName.contains("Alibaba Cloud Linux") && systemVersion.startsWith("3")) { - new SharedLibraryLoaderCentos8 - } else if (systemName.contains("Alibaba Cloud Linux") && systemVersion.startsWith("2")) { - new SharedLibraryLoaderCentos7 - } else if (systemName.contains("Anolis") && systemVersion.startsWith("8")) { - new SharedLibraryLoaderCentos8 - } else if (systemName.contains("Anolis") && systemVersion.startsWith("7")) { - new SharedLibraryLoaderCentos7 - } else if (system.contains("tencentos") && system.contains("2.4")) { - new SharedLibraryLoaderCentos7 - } else if (system.contains("tencentos") && system.contains("3.2")) { - new SharedLibraryLoaderCentos8 - } else if (systemName.contains("Red Hat") && systemVersion.startsWith("9")) { - new SharedLibraryLoaderCentos9 - } else if (systemName.contains("Red Hat") && systemVersion.startsWith("8")) { - new SharedLibraryLoaderCentos8 - } else if (systemName.contains("Red Hat") && systemVersion.startsWith("7")) { - new SharedLibraryLoaderCentos7 - } else if (systemName.contains("Debian") && 
systemVersion.startsWith("11")) { - new SharedLibraryLoaderDebian11 - } else if (systemName.contains("Debian") && systemVersion.startsWith("12")) { - new SharedLibraryLoaderDebian12 - } else { - throw new GlutenException( - s"Found unsupported OS($systemName, $systemVersion)! Currently, Gluten's Velox backend" + - " only supports Ubuntu 20.04/22.04, CentOS 7/8, " + - "Alibaba Cloud Linux 2/3 & Anolis 7/8, tencentos 2.4/3.2, RedHat 7/8, " + - "Debian 11/12.") + // Static initializers for executor. + if (!executorInitialized.compareAndSet(false, true)) { + // Make sure we call the static initializers only once. + logInfo( + "Skip rerunning static initializers since they are only supposed to run once." + + " You see this message probably because you are creating a new SparkSession.") + return } - } - - private def loadLibFromJar(load: JniLibLoader, conf: SparkConf): Unit = { - val systemName = conf.getOption(GlutenConfig.GLUTEN_LOAD_LIB_OS) - val loader = if (systemName.isDefined) { - val systemVersion = conf.getOption(GlutenConfig.GLUTEN_LOAD_LIB_OS_VERSION) - if (systemVersion.isEmpty) { - throw new GlutenException( - s"${GlutenConfig.GLUTEN_LOAD_LIB_OS_VERSION} must be specified when specifies the " + - s"${GlutenConfig.GLUTEN_LOAD_LIB_OS}") - } - getLibraryLoaderForOS(systemName.get, systemVersion.get, "") - } else { - val system = "cat /etc/os-release".!! - val systemNamePattern = "^NAME=\"?(.*)\"?".r - val systemVersionPattern = "^VERSION=\"?(.*)\"?".r - val systemInfoLines = system.stripMargin.split("\n") - val systemNamePattern(systemName) = - systemInfoLines.find(_.startsWith("NAME=")).getOrElse("") - val systemVersionPattern(systemVersion) = - systemInfoLines.find(_.startsWith("VERSION=")).getOrElse("") - if (systemName.isEmpty || systemVersion.isEmpty) { - throw new GlutenException("Failed to get OS name and version info.") - } - getLibraryLoaderForOS(systemName, systemVersion, system) + if (inLocalMode(conf)) { + // Don't do static initializations from executor side in local mode. + // Driver already did that. + logInfo( + "Gluten is running with Spark local mode. Skip running static initializer for executor.") + return } - loader.loadLib(load) - } - private def loadLibWithLinux(conf: SparkConf, loader: JniLibLoader): Unit = { - if ( - conf.getBoolean( - GlutenConfig.GLUTEN_LOAD_LIB_FROM_JAR, - GlutenConfig.GLUTEN_LOAD_LIB_FROM_JAR_DEFAULT) - ) { - loadLibFromJar(loader, conf) - } + SparkDirectoryUtil.init(conf) + UDFResolver.resolveUdfConf(conf, isDriver = false) + initialize(conf) } - private def loadLibWithMacOS(loader: JniLibLoader): Unit = { - // Placeholder for loading shared libs on MacOS if user needs. - } + override def onExecutorShutdown(): Unit = shutdown() - private def initialize(conf: SparkConf, isDriver: Boolean): Unit = { - SparkDirectoryUtil.init(conf) - UDFResolver.resolveUdfConf(conf, isDriver = isDriver) + private def initialize(conf: SparkConf): Unit = { if (conf.getBoolean(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE, defaultValue = false)) { val debugDir = conf.get(GlutenConfig.GLUTEN_DEBUG_KEEP_JNI_WORKSPACE_DIR) JniWorkspace.enableDebug(debugDir) } - val loader = JniWorkspace.getDefault.libLoader - - val osName = System.getProperty("os.name") - if (osName.startsWith("Mac OS X") || osName.startsWith("macOS")) { - loadLibWithMacOS(loader) - } else { - loadLibWithLinux(conf, loader) - } // Set the system properties. // Use appending policy for children with the same name in a arrow struct vector. 
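The refactored VeloxListenerApi in the hunks around this point guards its driver- and executor-side static initialization with AtomicBoolean flags, so recreating a SparkSession does not rerun one-time setup, and it skips executor-side initialization entirely in local mode because the driver already performed it. The following is a minimal, self-contained Scala sketch of that compareAndSet guard pattern; the object and method names are illustrative only, not Gluten's actual API.

import java.util.concurrent.atomic.AtomicBoolean

// Hypothetical once-only initializer guard mirroring the pattern above.
object OnceOnlyInit {
  private val initialized = new AtomicBoolean(false)

  def initializeIfNeeded(doInit: () => Unit): Unit = {
    // compareAndSet returns false when another caller already flipped the
    // flag, so the initializer body runs at most once per JVM.
    if (!initialized.compareAndSet(false, true)) {
      println("Static initializers already ran; skipping.")
      return
    }
    doInit()
  }
}

// Usage: only the first call runs the initializer, later calls are no-ops.
// OnceOnlyInit.initializeIfNeeded(() => println("heavy one-time setup"))
// OnceOnlyInit.initializeIfNeeded(() => println("never printed"))
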
@@ -167,6 +104,13 @@ class VeloxListenerApi extends ListenerApi { // Load supported hive/python/scala udfs UDFMappings.loadFromSparkConf(conf) + // Initial library loader. + val loader = JniWorkspace.getDefault.libLoader + + // Load shared native libraries the backend libraries depend on. + SharedLibraryLoader.load(conf, loader) + + // Load backend libraries. val libPath = conf.get(GlutenConfig.GLUTEN_LIB_PATH, StringUtils.EMPTY) if (StringUtils.isNotBlank(libPath)) { // Path based load. Ignore all other loadees. JniLibLoader.loadFromPath(libPath, false) @@ -176,11 +120,11 @@ class VeloxListenerApi extends ListenerApi { loader.mapAndLoad(VeloxBackend.BACKEND_NAME, false) } + // Initial native backend with configurations. val parsed = GlutenConfigUtil.parseConfig(conf.getAll.toMap) NativeBackendInitializer.initializeBackend(parsed) - // inject backend-specific implementations to override spark classes - // FIXME: The following set instances twice in local mode? + // Inject backend-specific implementations to override spark classes. GlutenParquetWriterInjects.setInstance(new VeloxParquetWriterInjects()) GlutenOrcWriterInjects.setInstance(new VeloxOrcWriterInjects()) GlutenRowSplitter.setInstance(new VeloxRowSplitter()) @@ -191,4 +135,13 @@ class VeloxListenerApi extends ListenerApi { } } -object VeloxListenerApi {} +object VeloxListenerApi { + // TODO: Implement graceful shutdown and remove these flags. + // As spark conf may change when active Spark session is recreated. + private val driverInitialized: AtomicBoolean = new AtomicBoolean(false) + private val executorInitialized: AtomicBoolean = new AtomicBoolean(false) + + private def inLocalMode(conf: SparkConf): Boolean = { + SparkResourceUtil.isLocalMaster(conf) + } +} diff --git a/backends-velox/src/main/scala/org/apache/gluten/utils/SharedLibraryLoader.scala b/backends-velox/src/main/scala/org/apache/gluten/utils/SharedLibraryLoader.scala index 137da83c0980..1f3ca30de9f5 100755 --- a/backends-velox/src/main/scala/org/apache/gluten/utils/SharedLibraryLoader.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/utils/SharedLibraryLoader.scala @@ -16,8 +16,112 @@ */ package org.apache.gluten.utils +import org.apache.gluten.GlutenConfig +import org.apache.gluten.exception.GlutenException import org.apache.gluten.vectorized.JniLibLoader +import org.apache.spark.SparkConf + +import scala.sys.process._ + trait SharedLibraryLoader { def loadLib(loader: JniLibLoader): Unit } + +object SharedLibraryLoader { + def load(conf: SparkConf, jni: JniLibLoader): Unit = { + val shouldLoad = conf.getBoolean( + GlutenConfig.GLUTEN_LOAD_LIB_FROM_JAR, + GlutenConfig.GLUTEN_LOAD_LIB_FROM_JAR_DEFAULT) + if (!shouldLoad) { + return + } + val osName = System.getProperty("os.name") + if (osName.startsWith("Mac OS X") || osName.startsWith("macOS")) { + loadLibWithMacOS(jni) + } else { + loadLibWithLinux(conf, jni) + } + } + + private def loadLibWithLinux(conf: SparkConf, jni: JniLibLoader): Unit = { + val loader = find(conf) + loader.loadLib(jni) + } + + private def loadLibWithMacOS(jni: JniLibLoader): Unit = { + // Placeholder for loading shared libs on MacOS if user needs. 
+ } + + private def find(conf: SparkConf): SharedLibraryLoader = { + val systemName = conf.getOption(GlutenConfig.GLUTEN_LOAD_LIB_OS) + val loader = if (systemName.isDefined) { + val systemVersion = conf.getOption(GlutenConfig.GLUTEN_LOAD_LIB_OS_VERSION) + if (systemVersion.isEmpty) { + throw new GlutenException( + s"${GlutenConfig.GLUTEN_LOAD_LIB_OS_VERSION} must be specified when specifies the " + + s"${GlutenConfig.GLUTEN_LOAD_LIB_OS}") + } + getForOS(systemName.get, systemVersion.get, "") + } else { + val system = "cat /etc/os-release".!! + val systemNamePattern = "^NAME=\"?(.*)\"?".r + val systemVersionPattern = "^VERSION=\"?(.*)\"?".r + val systemInfoLines = system.stripMargin.split("\n") + val systemNamePattern(systemName) = + systemInfoLines.find(_.startsWith("NAME=")).getOrElse("") + val systemVersionPattern(systemVersion) = + systemInfoLines.find(_.startsWith("VERSION=")).getOrElse("") + if (systemName.isEmpty || systemVersion.isEmpty) { + throw new GlutenException("Failed to get OS name and version info.") + } + getForOS(systemName, systemVersion, system) + } + loader + } + + private def getForOS( + systemName: String, + systemVersion: String, + system: String): SharedLibraryLoader = { + if (systemName.contains("Ubuntu") && systemVersion.startsWith("20.04")) { + new SharedLibraryLoaderUbuntu2004 + } else if (systemName.contains("Ubuntu") && systemVersion.startsWith("22.04")) { + new SharedLibraryLoaderUbuntu2204 + } else if (systemName.contains("CentOS") && systemVersion.startsWith("9")) { + new SharedLibraryLoaderCentos9 + } else if (systemName.contains("CentOS") && systemVersion.startsWith("8")) { + new SharedLibraryLoaderCentos8 + } else if (systemName.contains("CentOS") && systemVersion.startsWith("7")) { + new SharedLibraryLoaderCentos7 + } else if (systemName.contains("Alibaba Cloud Linux") && systemVersion.startsWith("3")) { + new SharedLibraryLoaderCentos8 + } else if (systemName.contains("Alibaba Cloud Linux") && systemVersion.startsWith("2")) { + new SharedLibraryLoaderCentos7 + } else if (systemName.contains("Anolis") && systemVersion.startsWith("8")) { + new SharedLibraryLoaderCentos8 + } else if (systemName.contains("Anolis") && systemVersion.startsWith("7")) { + new SharedLibraryLoaderCentos7 + } else if (system.contains("tencentos") && system.contains("2.4")) { + new SharedLibraryLoaderCentos7 + } else if (system.contains("tencentos") && system.contains("3.2")) { + new SharedLibraryLoaderCentos8 + } else if (systemName.contains("Red Hat") && systemVersion.startsWith("9")) { + new SharedLibraryLoaderCentos9 + } else if (systemName.contains("Red Hat") && systemVersion.startsWith("8")) { + new SharedLibraryLoaderCentos8 + } else if (systemName.contains("Red Hat") && systemVersion.startsWith("7")) { + new SharedLibraryLoaderCentos7 + } else if (systemName.contains("Debian") && systemVersion.startsWith("11")) { + new SharedLibraryLoaderDebian11 + } else if (systemName.contains("Debian") && systemVersion.startsWith("12")) { + new SharedLibraryLoaderDebian12 + } else { + throw new GlutenException( + s"Found unsupported OS($systemName, $systemVersion)! 
Currently, Gluten's Velox backend" + + " only supports Ubuntu 20.04/22.04, CentOS 7/8, " + + "Alibaba Cloud Linux 2/3 & Anolis 7/8, tencentos 2.4/3.2, RedHat 7/8, " + + "Debian 11/12.") + } + } +} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala similarity index 99% rename from backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala rename to backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala index a0ea7d7267b4..fa7eae37b1c9 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala @@ -35,7 +35,7 @@ import java.util.concurrent.TimeUnit import scala.collection.JavaConverters -class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper { +class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSparkPlanHelper { protected val rootPath: String = getClass.getResource("/").getPath override protected val resourcePath: String = "/tpch-data-parquet-velox" @@ -2078,4 +2078,12 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla checkGlutenOperatorMatch[SortExecTransformer] } } + + // Enable the test after fixing https://github.com/apache/incubator-gluten/issues/6827 + ignore("Test round expression") { + val df1 = runQueryAndCompare("SELECT round(cast(0.5549999999999999 as double), 2)") { _ => } + checkLengthAndPlan(df1, 1) + val df2 = runQueryAndCompare("SELECT round(cast(0.19324999999999998 as double), 2)") { _ => } + checkLengthAndPlan(df2, 1) + } } diff --git a/cpp-ch/clickhouse.version b/cpp-ch/clickhouse.version index 7c93bc1240ce..d41875c54d7d 100644 --- a/cpp-ch/clickhouse.version +++ b/cpp-ch/clickhouse.version @@ -1,3 +1,3 @@ CH_ORG=Kyligence -CH_BRANCH=rebase_ch/20240815 -CH_COMMIT=d87dbba64fc +CH_BRANCH=rebase_ch/20240817 +CH_COMMIT=ed191291681 diff --git a/cpp-ch/local-engine/Common/CHUtil.cpp b/cpp-ch/local-engine/Common/CHUtil.cpp index 0409b66bd920..8e07eea011b8 100644 --- a/cpp-ch/local-engine/Common/CHUtil.cpp +++ b/cpp-ch/local-engine/Common/CHUtil.cpp @@ -979,6 +979,7 @@ void BackendInitializerUtil::init(const std::string_view plan) // Init the table metadata cache map StorageMergeTreeFactory::init_cache_map(); + JobScheduler::initialize(SerializedPlanParser::global_context); CacheManager::initialize(SerializedPlanParser::global_context); std::call_once( diff --git a/cpp-ch/local-engine/Common/ConcurrentMap.h b/cpp-ch/local-engine/Common/ConcurrentMap.h index 1719d9b255ea..2db35102215a 100644 --- a/cpp-ch/local-engine/Common/ConcurrentMap.h +++ b/cpp-ch/local-engine/Common/ConcurrentMap.h @@ -16,7 +16,7 @@ */ #pragma once -#include +#include #include namespace local_engine diff --git a/cpp-ch/local-engine/Common/GlutenConfig.h b/cpp-ch/local-engine/Common/GlutenConfig.h index 84744dab21b8..ac82b0fff03a 100644 --- a/cpp-ch/local-engine/Common/GlutenConfig.h +++ b/cpp-ch/local-engine/Common/GlutenConfig.h @@ -183,5 +183,19 @@ struct MergeTreeConfig return config; } }; + +struct GlutenJobSchedulerConfig +{ + inline static const String JOB_SCHEDULER_MAX_THREADS = "job_scheduler_max_threads"; + + size_t job_scheduler_max_threads = 10; + + static GlutenJobSchedulerConfig loadFromContext(DB::ContextPtr context) + { + GlutenJobSchedulerConfig config; + config.job_scheduler_max_threads = 
context->getConfigRef().getUInt64(JOB_SCHEDULER_MAX_THREADS, 10); + return config; + } +}; } diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp index d2c7b06810db..a97f0c72ada4 100644 --- a/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp +++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.cpp @@ -26,12 +26,13 @@ #include #include #include -#include #include #include #include #include +#include + namespace DB { namespace ErrorCodes @@ -49,6 +50,16 @@ extern const Metric LocalThreadScheduled; namespace local_engine { + +jclass CacheManager::cache_result_class = nullptr; +jmethodID CacheManager::cache_result_constructor = nullptr; + +void CacheManager::initJNI(JNIEnv * env) +{ + cache_result_class = CreateGlobalClassReference(env, "Lorg/apache/gluten/execution/CacheResult;"); + cache_result_constructor = GetMethodID(env, cache_result_class, "", "(ILjava/lang/String;)V"); +} + CacheManager & CacheManager::instance() { static CacheManager cache_manager; @@ -59,13 +70,6 @@ void CacheManager::initialize(DB::ContextMutablePtr context_) { auto & manager = instance(); manager.context = context_; - manager.thread_pool = std::make_unique( - CurrentMetrics::LocalThread, - CurrentMetrics::LocalThreadActive, - CurrentMetrics::LocalThreadScheduled, - manager.context->getConfigRef().getInt("cache_sync_max_threads", 10), - 0, - 0); } struct CacheJobContext @@ -73,17 +77,16 @@ struct CacheJobContext MergeTreeTable table; }; -void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set & columns, std::shared_ptr latch) +Task CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set & columns) { CacheJobContext job_context{table}; job_context.table.parts.clear(); job_context.table.parts.push_back(part); job_context.table.snapshot_id = ""; - auto job = [job_detail = job_context, context = this->context, read_columns = columns, latch = latch]() + Task task = [job_detail = job_context, context = this->context, read_columns = columns]() { try { - SCOPE_EXIT({ if (latch) latch->count_down();}); auto storage = MergeTreeRelParser::parseStorage(job_detail.table, context, true); auto storage_snapshot = std::make_shared(*storage, storage->getInMemoryMetadataPtr()); NamesAndTypesList names_and_types_list; @@ -113,8 +116,7 @@ void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p PullingPipelineExecutor executor(pipeline); while (true) { - Chunk chunk; - if (!executor.pull(chunk)) + if (Chunk chunk; !executor.pull(chunk)) break; } LOG_INFO(getLogger("CacheManager"), "Load cache of table {}.{} part {} success.", job_detail.table.database, job_detail.table.table, job_detail.table.parts.front().name); @@ -122,22 +124,58 @@ void CacheManager::cachePart(const MergeTreeTable& table, const MergeTreePart& p catch (std::exception& e) { LOG_ERROR(getLogger("CacheManager"), "Load cache of table {}.{} part {} failed.\n {}", job_detail.table.database, job_detail.table.table, job_detail.table.parts.front().name, e.what()); + std::rethrow_exception(std::current_exception()); } }; LOG_INFO(getLogger("CacheManager"), "Loading cache of table {}.{} part {}", job_context.table.database, job_context.table.table, job_context.table.parts.front().name); - thread_pool->scheduleOrThrowOnError(std::move(job)); + return std::move(task); } -void CacheManager::cacheParts(const String& table_def, const std::unordered_set& columns, bool async) +JobId 
CacheManager::cacheParts(const String& table_def, const std::unordered_set& columns) { auto table = parseMergeTreeTableString(table_def); - std::shared_ptr latch = nullptr; - if (!async) latch = std::make_shared(table.parts.size()); + JobId id = toString(UUIDHelpers::generateV4()); + Job job(id); for (const auto & part : table.parts) { - cachePart(table, part, columns, latch); + job.addTask(cachePart(table, part, columns)); + } + auto& scheduler = JobScheduler::instance(); + scheduler.scheduleJob(std::move(job)); + return id; +} + +jobject CacheManager::getCacheStatus(JNIEnv * env, const String & jobId) +{ + auto& scheduler = JobScheduler::instance(); + auto job_status = scheduler.getJobSatus(jobId); + int status = 0; + String message; + if (job_status.has_value()) + { + switch (job_status.value().status) + { + case JobSatus::RUNNING: + status = 0; + break; + case JobSatus::FINISHED: + status = 1; + break; + case JobSatus::FAILED: + status = 2; + for (const auto & msg : job_status->messages) + { + message.append(msg); + message.append(";"); + } + break; + } + } + else + { + status = 2; + message = fmt::format("job {} not found", jobId); } - if (latch) - latch->wait(); + return env->NewObject(cache_result_class, cache_result_constructor, status, charTojstring(env, message.c_str())); } } \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/CacheManager.h b/cpp-ch/local-engine/Storages/Cache/CacheManager.h index a303b7b7fc63..b88a3ea03e4e 100644 --- a/cpp-ch/local-engine/Storages/Cache/CacheManager.h +++ b/cpp-ch/local-engine/Storages/Cache/CacheManager.h @@ -16,29 +16,32 @@ */ #pragma once #include -#include - +#include +#include namespace local_engine { struct MergeTreePart; struct MergeTreeTable; + + + /*** * Manage the cache of the MergeTree, mainly including meta.bin, data.bin, metadata.gluten */ class CacheManager { public: + static jclass cache_result_class; + static jmethodID cache_result_constructor; + static void initJNI(JNIEnv* env); + static CacheManager & instance(); static void initialize(DB::ContextMutablePtr context); - void cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set& columns, std::shared_ptr latch = nullptr); - void cacheParts(const String& table_def, const std::unordered_set& columns, bool async = true); + Task cachePart(const MergeTreeTable& table, const MergeTreePart& part, const std::unordered_set& columns); + JobId cacheParts(const String& table_def, const std::unordered_set& columns); + static jobject getCacheStatus(JNIEnv * env, const String& jobId); private: CacheManager() = default; - - std::unique_ptr thread_pool; DB::ContextMutablePtr context; - std::unordered_map policy_to_disk; - std::unordered_map disk_to_metadisk; - std::unordered_map policy_to_cache; }; } \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp new file mode 100644 index 000000000000..6a43ad644433 --- /dev/null +++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.cpp @@ -0,0 +1,163 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +#include "JobScheduler.h" + +#include +#include +#include +#include + +namespace DB +{ +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +} +} + +namespace CurrentMetrics +{ +extern const Metric LocalThread; +extern const Metric LocalThreadActive; +extern const Metric LocalThreadScheduled; +} + +namespace local_engine +{ +std::shared_ptr global_job_scheduler = nullptr; + +void JobScheduler::initialize(DB::ContextPtr context) +{ + auto config = GlutenJobSchedulerConfig::loadFromContext(context); + instance().thread_pool = std::make_unique( + CurrentMetrics::LocalThread, + CurrentMetrics::LocalThreadActive, + CurrentMetrics::LocalThreadScheduled, + config.job_scheduler_max_threads, + 0, + 0); + +} + +JobId JobScheduler::scheduleJob(Job&& job) +{ + cleanFinishedJobs(); + if (job_details.contains(job.id)) + { + throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "job {} exists.", job.id); + } + size_t task_num = job.tasks.size(); + auto job_id = job.id; + std::vector task_results; + task_results.reserve(task_num); + JobContext job_context = {std::move(job), std::make_unique(task_num), std::move(task_results)}; + { + std::lock_guard lock(job_details_mutex); + job_details.emplace(job_id, std::move(job_context)); + } + LOG_INFO(logger, "schedule job {}", job_id); + + auto & job_detail = job_details.at(job_id); + + for (auto & task : job_detail.job.tasks) + { + job_detail.task_results.emplace_back(TaskResult()); + auto & task_result = job_detail.task_results.back(); + thread_pool->scheduleOrThrow( + [&]() + { + SCOPE_EXIT({ + job_detail.remain_tasks->fetch_sub(1, std::memory_order::acquire); + if (job_detail.isFinished()) + { + addFinishedJob(job_detail.job.id); + } + }); + try + { + task(); + task_result.status = TaskResult::Status::SUCCESS; + } + catch (std::exception & e) + { + task_result.status = TaskResult::Status::FAILED; + task_result.message = e.what(); + } + }); + } + return job_id; +} + +std::optional JobScheduler::getJobSatus(const JobId & job_id) +{ + if (!job_details.contains(job_id)) + { + return std::nullopt; + } + std::optional res; + auto & job_context = job_details.at(job_id); + if (job_context.isFinished()) + { + std::vector messages; + for (auto & task_result : job_context.task_results) + { + if (task_result.status == TaskResult::Status::FAILED) + { + messages.push_back(task_result.message); + } + } + if (messages.empty()) + res = JobSatus::success(); + else + res= JobSatus::failed(messages); + } + else + res = JobSatus::running(); + return res; +} + +void JobScheduler::cleanupJob(const JobId & job_id) +{ + LOG_INFO(logger, "clean job {}", job_id); + job_details.erase(job_id); +} + +void JobScheduler::addFinishedJob(const JobId & job_id) +{ + std::lock_guard lock(finished_job_mutex); + auto job = std::make_pair(job_id, Stopwatch()); + finished_job.emplace_back(job); +} + +void JobScheduler::cleanFinishedJobs() +{ + std::lock_guard lock(finished_job_mutex); + for (auto it = finished_job.begin(); it != finished_job.end();) + { + // clean finished job after 5 minutes + if (it->second.elapsedSeconds() > 60 * 5) + { + cleanupJob(it->first); + it = 
finished_job.erase(it); + } + else + ++it; + } +} +} \ No newline at end of file diff --git a/cpp-ch/local-engine/Storages/Cache/JobScheduler.h b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h new file mode 100644 index 000000000000..b5c2f601a92b --- /dev/null +++ b/cpp-ch/local-engine/Storages/Cache/JobScheduler.h @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once +#include +#include +#include +#include + +namespace local_engine +{ + +using JobId = String; +using Task = std::function; + +class Job +{ + friend class JobScheduler; +public: + explicit Job(const JobId& id) + : id(id) + { + } + + void addTask(Task&& task) + { + tasks.emplace_back(task); + } + +private: + JobId id; + std::vector tasks; +}; + + + +struct JobSatus +{ + enum Status + { + RUNNING, + FINISHED, + FAILED + }; + Status status; + std::vector messages; + + static JobSatus success() + { + return JobSatus{FINISHED}; + } + + static JobSatus running() + { + return JobSatus{RUNNING}; + } + + static JobSatus failed(const std::vector & messages) + { + return JobSatus{FAILED, messages}; + } +}; + +struct TaskResult +{ + enum Status + { + SUCCESS, + FAILED, + RUNNING + }; + Status status = RUNNING; + String message; +}; + +class JobContext +{ +public: + Job job; + std::unique_ptr remain_tasks = std::make_unique(); + std::vector task_results; + + bool isFinished() + { + return remain_tasks->load(std::memory_order::relaxed) == 0; + } +}; + +class JobScheduler +{ +public: + static JobScheduler& instance() + { + static JobScheduler global_job_scheduler; + return global_job_scheduler; + } + + static void initialize(DB::ContextPtr context); + + JobId scheduleJob(Job&& job); + + std::optional getJobSatus(const JobId& job_id); + + void cleanupJob(const JobId& job_id); + + void addFinishedJob(const JobId& job_id); + + void cleanFinishedJobs(); +private: + JobScheduler() = default; + std::unique_ptr thread_pool; + std::unordered_map job_details; + std::mutex job_details_mutex; + + std::vector> finished_job; + std::mutex finished_job_mutex; + LoggerPtr logger = getLogger("JobScheduler"); +}; +} diff --git a/cpp-ch/local-engine/local_engine_jni.cpp b/cpp-ch/local-engine/local_engine_jni.cpp index 828556b4abf6..3c3d6d4f89c2 100644 --- a/cpp-ch/local-engine/local_engine_jni.cpp +++ b/cpp-ch/local-engine/local_engine_jni.cpp @@ -163,6 +163,7 @@ JNIEXPORT jint JNI_OnLoad(JavaVM * vm, void * /*reserved*/) env, local_engine::SparkRowToCHColumn::spark_row_interator_class, "nextBatch", "()Ljava/nio/ByteBuffer;"); local_engine::BroadCastJoinBuilder::init(env); + local_engine::CacheManager::initJNI(env); local_engine::JNIUtils::vm = vm; return JNI_VERSION_1_8; @@ -1269,7 +1270,7 @@ JNIEXPORT void Java_org_apache_gluten_utils_TestExceptionUtils_generateNativeExc 
-JNIEXPORT void Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_, jboolean async_) +JNIEXPORT jstring Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCacheParts(JNIEnv * env, jobject, jstring table_, jstring columns_) { LOCAL_ENGINE_JNI_METHOD_START auto table_def = jstring2string(env, table_); @@ -1280,10 +1281,17 @@ JNIEXPORT void Java_org_apache_gluten_execution_CHNativeCacheManager_nativeCache { column_set.insert(col); } - local_engine::CacheManager::instance().cacheParts(table_def, column_set, async_); - LOCAL_ENGINE_JNI_METHOD_END(env, ); + auto id = local_engine::CacheManager::instance().cacheParts(table_def, column_set); + return local_engine::charTojstring(env, id.c_str()); + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr); } +JNIEXPORT jobject Java_org_apache_gluten_execution_CHNativeCacheManager_nativeGetCacheStatus(JNIEnv * env, jobject, jstring id) +{ + LOCAL_ENGINE_JNI_METHOD_START + return local_engine::CacheManager::instance().getCacheStatus(env, jstring2string(env, id)); + LOCAL_ENGINE_JNI_METHOD_END(env, nullptr); +} #ifdef __cplusplus } diff --git a/cpp-ch/local-engine/tests/data/68131.parquet b/cpp-ch/local-engine/tests/data/68131.parquet new file mode 100644 index 000000000000..169f6152003d Binary files /dev/null and b/cpp-ch/local-engine/tests/data/68131.parquet differ diff --git a/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp b/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp index 9e4165d90437..5b5797ed7d21 100644 --- a/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp +++ b/cpp-ch/local-engine/tests/gtest_clickhouse_pr_verify.cpp @@ -34,8 +34,8 @@ TEST(Clickhouse, PR54881) { const auto context1 = DB::Context::createCopy(SerializedPlanParser::global_context); // context1->setSetting("enable_named_columns_in_function_tuple", DB::Field(true)); - auto settingxs = context1->getSettingsRef(); - EXPECT_FALSE(settingxs.enable_named_columns_in_function_tuple) << "GLUTEN NEED set enable_named_columns_in_function_tuple to false"; + auto settings = context1->getSettingsRef(); + EXPECT_FALSE(settings.enable_named_columns_in_function_tuple) << "GLUTEN NEED set enable_named_columns_in_function_tuple to false"; const std::string split_template = R"({"items":[{"uriFile":"{replace_local_files}","partitionIndex":"0","length":"1529","parquet":{},"schema":{},"metadataColumns":[{}]}]})"; @@ -112,6 +112,26 @@ TEST(Clickhouse, PR68135) const auto plan = local_engine::JsonStringToMessage( {reinterpret_cast(gresource_embedded_pr_68135_jsonData), gresource_embedded_pr_68135_jsonSize}); + auto local_executor = parser.createExecutor(plan); + EXPECT_TRUE(local_executor->hasNext()); + const Block & x = *local_executor->nextColumnar(); + debug::headBlock(x); +} + +INCBIN(resource_embedded_pr_68131_json, SOURCE_DIR "/utils/extern-local-engine/tests/json/clickhouse_pr_68131.json"); +TEST(Clickhouse, PR68131) +{ + const std::string split_template + = R"({"items":[{"uriFile":"{replace_local_files}","partitionIndex":"0","length":"289","parquet":{},"schema":{},"metadataColumns":[{}]}]})"; + const std::string split + = replaceLocalFilesWildcards(split_template, GLUTEN_DATA_DIR("/utils/extern-local-engine/tests/data/68131.parquet")); + + SerializedPlanParser parser(SerializedPlanParser::global_context); + parser.addSplitInfo(local_engine::JsonStringToBinary(split)); + + const auto plan = local_engine::JsonStringToMessage( + {reinterpret_cast(gresource_embedded_pr_68131_jsonData), 
gresource_embedded_pr_68131_jsonSize}); + auto local_executor = parser.createExecutor(plan); EXPECT_TRUE(local_executor->hasNext()); const Block & x = *local_executor->nextColumnar(); diff --git a/cpp-ch/local-engine/tests/json/clickhouse_pr_68131.json b/cpp-ch/local-engine/tests/json/clickhouse_pr_68131.json new file mode 100644 index 000000000000..0add2092b817 --- /dev/null +++ b/cpp-ch/local-engine/tests/json/clickhouse_pr_68131.json @@ -0,0 +1,95 @@ +{ + "extensions": [ + { + "extensionFunction": { + "name": "is_not_null:list" + } + } + ], + "relations": [ + { + "root": { + "input": { + "filter": { + "common": { + "direct": {} + }, + "input": { + "read": { + "common": { + "direct": {} + }, + "baseSchema": { + "names": [ + "f" + ], + "struct": { + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "nullability": "NULLABILITY_NULLABLE" + } + } + ] + }, + "columnTypes": [ + "NORMAL_COL" + ] + }, + "advancedExtension": { + "optimization": { + "@type": "type.googleapis.com/google.protobuf.StringValue", + "value": "isMergeTree=0\n" + } + } + } + }, + "condition": { + "scalarFunction": { + "outputType": { + "bool": { + "nullability": "NULLABILITY_REQUIRED" + } + }, + "arguments": [ + { + "value": { + "selection": { + "directReference": { + "structField": {} + } + } + } + } + ] + } + } + } + }, + "names": [ + "f#0" + ], + "outputSchema": { + "types": [ + { + "list": { + "type": { + "i32": { + "nullability": "NULLABILITY_NULLABLE" + } + }, + "nullability": "NULLABILITY_REQUIRED" + } + } + ], + "nullability": "NULLABILITY_REQUIRED" + } + } + } + ] +} \ No newline at end of file diff --git a/cpp/core/config/GlutenConfig.cc b/cpp/core/config/GlutenConfig.cc index fa04ecfa4e5c..bc6ad1cbe859 100644 --- a/cpp/core/config/GlutenConfig.cc +++ b/cpp/core/config/GlutenConfig.cc @@ -15,13 +15,26 @@ * limitations under the License. 
*/ +#include #include - +#include #include "compute/ProtobufUtils.h" #include "config.pb.h" #include "jni/JniError.h" +namespace { + +std::optional getRedactionRegex(const std::unordered_map& conf) { + auto it = conf.find(gluten::kSparkRedactionRegex); + if (it != conf.end()) { + return boost::regex(it->second); + } + return std::nullopt; +} +} // namespace + namespace gluten { + std::unordered_map parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength) { std::unordered_map sparkConfs; @@ -37,9 +50,17 @@ parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength) std::string printConfig(const std::unordered_map& conf) { std::ostringstream oss; oss << std::endl; - for (auto& [k, v] : conf) { - oss << " [" << k << ", " << v << "]\n"; + + auto redactionRegex = getRedactionRegex(conf); + + for (const auto& [k, v] : conf) { + if (redactionRegex && boost::regex_match(k, *redactionRegex)) { + oss << " [" << k << ", " << kSparkRedactionString << "]\n"; + } else { + oss << " [" << k << ", " << v << "]\n"; + } } return oss.str(); } + } // namespace gluten diff --git a/cpp/core/config/GlutenConfig.h b/cpp/core/config/GlutenConfig.h index 057d85930d2a..31318ff0aa0c 100644 --- a/cpp/core/config/GlutenConfig.h +++ b/cpp/core/config/GlutenConfig.h @@ -66,6 +66,9 @@ const std::string kShuffleCompressionCodecBackend = "spark.gluten.sql.columnar.s const std::string kQatBackendName = "qat"; const std::string kIaaBackendName = "iaa"; +const std::string kSparkRedactionRegex = "spark.redaction.regex"; +const std::string kSparkRedactionString = "*********(redacted)"; + std::unordered_map parseConfMap(JNIEnv* env, const uint8_t* planData, const int32_t planDataLength); diff --git a/ep/build-velox/src/get_velox.sh b/ep/build-velox/src/get_velox.sh index f857ceada653..ba967e22f666 100755 --- a/ep/build-velox/src/get_velox.sh +++ b/ep/build-velox/src/get_velox.sh @@ -17,7 +17,7 @@ set -exu VELOX_REPO=https://github.com/oap-project/velox.git -VELOX_BRANCH=2024_08_16 +VELOX_BRANCH=2024_08_19 VELOX_HOME="" OS=`uname -s` diff --git a/gluten-celeborn/clickhouse/pom.xml b/gluten-celeborn/clickhouse/pom.xml index 284a8f57282a..9e64e77ce6ea 100755 --- a/gluten-celeborn/clickhouse/pom.xml +++ b/gluten-celeborn/clickhouse/pom.xml @@ -148,6 +148,38 @@ ${hadoop.version} test + + org.apache.arrow + arrow-memory-core + ${arrow.version} + provided + + + io.netty + netty-common + + + io.netty + netty-buffer + + + + + org.apache.arrow + arrow-vector + ${arrow.version} + provided + + + io.netty + netty-common + + + io.netty + netty-buffer + + + diff --git a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala index 5072ce6a1a2e..3b8e92bfe1d2 100644 --- a/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala +++ b/gluten-celeborn/clickhouse/src/main/scala/org/apache/spark/shuffle/CHCelebornColumnarBatchSerializer.scala @@ -62,7 +62,9 @@ private class CHCelebornColumnarBatchSerializerInstance( private lazy val compressionCodec = GlutenShuffleUtils.getCompressionCodec(conf) private lazy val capitalizedCompressionCodec = compressionCodec.toUpperCase(Locale.ROOT) private lazy val compressionLevel = - GlutenShuffleUtils.getCompressionLevel(conf, compressionCodec, + GlutenShuffleUtils.getCompressionLevel( + conf, + compressionCodec, 
GlutenConfig.getConf.columnarShuffleCodecBackend.orNull) override def deserializeStream(in: InputStream): DeserializationStream = { diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkDirectoryUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkDirectoryUtil.scala index fbc59edfdd6b..833575178c66 100644 --- a/gluten-core/src/main/scala/org/apache/spark/util/SparkDirectoryUtil.scala +++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkDirectoryUtil.scala @@ -79,7 +79,7 @@ object SparkDirectoryUtil extends Logging { return } if (INSTANCE.roots.toSet != roots.toSet) { - logWarning( + throw new IllegalArgumentException( s"Reinitialize SparkDirectoryUtil with different root dirs: old: ${INSTANCE.ROOTS .mkString("Array(", ", ", ")")}, new: ${roots.mkString("Array(", ", ", ")")}" ) diff --git a/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala b/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala index b16d43de5d68..f8c791fe1374 100644 --- a/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala +++ b/gluten-core/src/main/scala/org/apache/spark/util/SparkResourceUtil.scala @@ -76,4 +76,8 @@ object SparkResourceUtil extends Logging { val taskCores = conf.getInt("spark.task.cpus", 1) executorCores / taskCores } + + def isLocalMaster(conf: SparkConf): Boolean = { + Utils.isLocalMaster(conf) + } } diff --git a/gluten-core/src/test/scala/org/apache/gluten/test/FallbackUtil.scala b/gluten-core/src/test/scala/org/apache/gluten/test/FallbackUtil.scala index d2626ab275ce..3d26dd16c4eb 100644 --- a/gluten-core/src/test/scala/org/apache/gluten/test/FallbackUtil.scala +++ b/gluten-core/src/test/scala/org/apache/gluten/test/FallbackUtil.scala @@ -20,11 +20,11 @@ import org.apache.gluten.extension.GlutenPlan import org.apache.spark.internal.Logging import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, AQEShuffleReadExec, QueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec /** - * attention: if AQE is enable,This method will only be executed correctly after the execution plan + * attention: if AQE is enabled,This method will only be executed correctly after the execution plan * is fully determined */ @@ -42,10 +42,14 @@ object FallbackUtil extends Logging with AdaptiveSparkPlanHelper { true case WholeStageCodegenExec(_) => true + case ColumnarInputAdapter(_) => + true case InputAdapter(_) => true case AdaptiveSparkPlanExec(_, _, _, _, _) => true + case AQEShuffleReadExec(_, _) => + true case _: LimitExec => true // for ut @@ -57,30 +61,15 @@ object FallbackUtil extends Logging with AdaptiveSparkPlanHelper { true case _: ReusedExchangeExec => true - case p: SparkPlan if p.supportsColumnar => - true case _ => false } } def hasFallback(plan: SparkPlan): Boolean = { - var fallbackOperator: Seq[SparkPlan] = null - if (plan.isInstanceOf[AdaptiveSparkPlanExec]) { - fallbackOperator = collectWithSubqueries(plan) { - case plan if !plan.isInstanceOf[GlutenPlan] && !skip(plan) => - plan - } - } else { - fallbackOperator = plan.collectWithSubqueries { - case plan if !plan.isInstanceOf[GlutenPlan] && !skip(plan) => - plan - } - } - - if (fallbackOperator.nonEmpty) { - fallbackOperator.foreach(operator => log.info(s"gluten fallback operator:{$operator}")) - } + val fallbackOperator = 
collectWithSubqueries(plan) { case plan => plan }.filterNot( plan => plan.isInstanceOf[GlutenPlan] || skip(plan)) fallbackOperator.foreach(operator => log.info(s"gluten fallback operator:{$operator}")) fallbackOperator.nonEmpty } } diff --git a/gluten-ut/pom.xml b/gluten-ut/pom.xml index 90644b832bf8..a016eccaed20 100644 --- a/gluten-ut/pom.xml +++ b/gluten-ut/pom.xml @@ -31,7 +31,7 @@ <artifactId>gluten-ut</artifactId> <packaging>pom</packaging> - <name>Gluten Unit Test</name> + <name>Gluten Unit Test Parent</name>
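Taken together, the native JobScheduler and the JNI changes above make MergeTree cache loading asynchronous: nativeCacheParts now returns a job id generated from a UUID, and nativeGetCacheStatus maps the scheduler's RUNNING/FINISHED/FAILED states to the 0/1/2 codes carried back in a CacheResult, with failed task messages joined by ";". A hedged sketch of how a JVM-side caller might poll that status; the CacheJobPoller object, the pollCacheJob helper, and its polling interval are illustrative and not part of this patch:

import java.util.{Set => JSet}

import org.apache.gluten.execution.{CHNativeCacheManager, CacheResult}

object CacheJobPoller {
  // Starts an asynchronous cache-loading job and polls its status until it
  // leaves the RUNNING state, then returns the final CacheResult.
  def pollCacheJob(tableDef: String, columns: JSet[String], intervalMs: Long = 500L): CacheResult = {
    val jobId = CHNativeCacheManager.cacheParts(tableDef, columns)
    var result = CHNativeCacheManager.getCacheStatus(jobId)
    while (result.getStatus == CacheResult.Status.RUNNING) {
      Thread.sleep(intervalMs)
      result = CHNativeCacheManager.getCacheStatus(jobId)
    }
    // On ERROR, getMessage() carries the failed tasks' messages joined by ';'.
    result
  }
}

Since the scheduler keeps finished jobs for roughly five minutes (they are only dropped by cleanFinishedJobs when a later job is scheduled), a caller that polls promptly will still observe the final status and any error messages.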