From 9059ff685f37367af8d84c6cc26bae05ec2a225f Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 3 Dec 2024 09:36:15 +0800 Subject: [PATCH 1/5] [GLUTEN-8060][CORE] GlutenShuffleManager as a registry of shuffle managers (#8084) --- .../spark/shuffle/GlutenShuffleManager.scala | 71 ++++ .../org/apache/spark/shuffle/LookupKey.scala | 27 ++ .../spark/shuffle/ShuffleManagerLookup.scala | 38 +++ .../shuffle/ShuffleManagerRegistry.scala | 94 ++++++ .../spark/shuffle/ShuffleManagerRouter.scala | 137 ++++++++ .../shuffle/GlutenShuffleManagerSuite.scala | 315 ++++++++++++++++++ 6 files changed, 682 insertions(+) create mode 100644 gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala create mode 100644 gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala create mode 100644 gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala create mode 100644 gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala create mode 100644 gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala create mode 100644 gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala new file mode 100644 index 0000000000000..d38781675baf0 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/GlutenShuffleManager.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.annotation.Experimental + +/** + * Shuffle manager that routes shuffle API calls to different shuffle managers registered by + * different backends. 
+ * + * A SPIP may cause refactoring of this class in the future: + * https://issues.apache.org/jira/browse/SPARK-45792 + */ +@Experimental +class GlutenShuffleManager(conf: SparkConf, isDriver: Boolean) extends ShuffleManager { + private val routerBuilder = ShuffleManagerRegistry.get().newRouterBuilder(conf, isDriver) + + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + routerBuilder.getOrBuild().registerShuffle(shuffleId, dependency) + } + + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + routerBuilder.getOrBuild().getWriter(handle, mapId, context, metrics) + } + + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + routerBuilder + .getOrBuild() + .getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics) + } + + override def unregisterShuffle(shuffleId: Int): Boolean = { + routerBuilder.getOrBuild().unregisterShuffle(shuffleId) + } + + override def shuffleBlockResolver: ShuffleBlockResolver = { + routerBuilder.getOrBuild().shuffleBlockResolver + } + + override def stop(): Unit = { + routerBuilder.getOrBuild().stop() + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala new file mode 100644 index 0000000000000..502dd92efe809 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/LookupKey.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark.ShuffleDependency + +/** + * Required during shuffle manager registration to determine whether the shuffle manager should be + * used for the particular shuffle dependency. + */ +trait LookupKey { + def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean +} diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala new file mode 100644 index 0000000000000..8b060c98184fc --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerLookup.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle + +import org.apache.spark.ShuffleDependency + +private class ShuffleManagerLookup(all: Seq[(LookupKey, ShuffleManager)]) { + private val allReversed = all.reverse + + def findShuffleManager[K, V, C](dependency: ShuffleDependency[K, V, C]): ShuffleManager = { + this.synchronized { + // The latest shuffle manager registered will be looked up earlier. + allReversed.find(_._1.accepts(dependency)).map(_._2).getOrElse { + throw new IllegalStateException(s"No ShuffleManager found for $dependency") + } + } + } + + def all(): Seq[ShuffleManager] = { + this.synchronized { + all.map(_._2) + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala new file mode 100644 index 0000000000000..4310054caa51e --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRegistry.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle + +import org.apache.spark.SparkConf +import org.apache.spark.util.Utils + +import scala.collection.mutable + +class ShuffleManagerRegistry private[ShuffleManagerRegistry] { + import ShuffleManagerRegistry._ + private val all: mutable.Buffer[(LookupKey, String)] = mutable.Buffer() + private val routerBuilders: mutable.Buffer[RouterBuilder] = mutable.Buffer() + private val classDeDup: mutable.Set[String] = mutable.Set() + + def register(lookupKey: LookupKey, shuffleManagerClass: String): Unit = { + val clazz = Utils.classForName(shuffleManagerClass) + require( + !clazz.isAssignableFrom(classOf[GlutenShuffleManager]), + "It's not allowed to register GlutenShuffleManager recursively") + require( + classOf[ShuffleManager].isAssignableFrom(clazz), + s"Shuffle manager class to register is not an implementation of Spark ShuffleManager: " + + s"$shuffleManagerClass" + ) + require( + !classDeDup.contains(shuffleManagerClass), + s"Shuffle manager class already registered: $shuffleManagerClass") + this.synchronized { + classDeDup += shuffleManagerClass + all += lookupKey -> shuffleManagerClass + // Invalidate all shuffle managers cached in each alive router builder instances. + // Then, once the router builder is accessed, a new router will be forced to create. + routerBuilders.foreach(_.invalidateCache()) + } + } + + // Visible for testing + private[shuffle] def clear(): Unit = { + this.synchronized { + classDeDup.clear() + all.clear() + routerBuilders.foreach(_.invalidateCache()) + } + } + + private[shuffle] def newRouterBuilder(conf: SparkConf, isDriver: Boolean): RouterBuilder = + this.synchronized { + val out = new RouterBuilder(this, conf, isDriver) + routerBuilders += out + out + } +} + +object ShuffleManagerRegistry { + private val instance = new ShuffleManagerRegistry() + + def get(): ShuffleManagerRegistry = instance + + class RouterBuilder(registry: ShuffleManagerRegistry, conf: SparkConf, isDriver: Boolean) { + private var router: Option[ShuffleManagerRouter] = None + + private[ShuffleManagerRegistry] def invalidateCache(): Unit = synchronized { + router = None + } + + private[shuffle] def getOrBuild(): ShuffleManagerRouter = synchronized { + if (router.isEmpty) { + val instances = registry.all.map(key => key._1 -> instantiate(key._2, conf, isDriver)) + router = Some(new ShuffleManagerRouter(new ShuffleManagerLookup(instances.toSeq))) + } + router.get + } + + private def instantiate(clazz: String, conf: SparkConf, isDriver: Boolean): ShuffleManager = { + Utils + .instantiateSerializerOrShuffleManager[ShuffleManager](clazz, conf, isDriver) + } + } +} diff --git a/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala new file mode 100644 index 0000000000000..80aa9d8047856 --- /dev/null +++ b/gluten-core/src/main/scala/org/apache/spark/shuffle/ShuffleManagerRouter.scala @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.shuffle +import org.apache.spark.{ShuffleDependency, TaskContext} +import org.apache.spark.network.buffer.ManagedBuffer +import org.apache.spark.network.shuffle.MergedBlockMeta +import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId, ShuffleMergedBlockId} + +/** The internal shuffle manager instance used by GlutenShuffleManager. */ +private class ShuffleManagerRouter(lookup: ShuffleManagerLookup) extends ShuffleManager { + import ShuffleManagerRouter._ + private val cache = new Cache() + private val resolver = new BlockResolver(cache) + + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + val manager = lookup.findShuffleManager(dependency) + cache.store(shuffleId, manager).registerShuffle(shuffleId, dependency) + } + + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + cache.get(handle.shuffleId).getWriter(handle, mapId, context, metrics) + } + + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + cache + .get(handle.shuffleId) + .getReader(handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics) + } + + override def unregisterShuffle(shuffleId: Int): Boolean = { + cache.remove(shuffleId).unregisterShuffle(shuffleId) + } + + override def shuffleBlockResolver: ShuffleBlockResolver = resolver + + override def stop(): Unit = { + assert(cache.size() == 0) + lookup.all().reverse.foreach(_.stop()) + } +} + +private object ShuffleManagerRouter { + private class Cache { + private val cache: java.util.Map[Int, ShuffleManager] = + new java.util.concurrent.ConcurrentHashMap() + + def store(shuffleId: Int, manager: ShuffleManager): ShuffleManager = { + cache.compute( + shuffleId, + (id, m) => { + assert(m == null, s"Shuffle manager was already cached for shuffle id: $id") + manager + }) + } + + def get(shuffleId: Int): ShuffleManager = { + val manager = cache.get(shuffleId) + assert(manager != null, s"Shuffle manager not registered for shuffle id: $shuffleId") + manager + } + + def remove(shuffleId: Int): ShuffleManager = { + val manager = cache.remove(shuffleId) + assert(manager != null, s"Shuffle manager not registered for shuffle id: $shuffleId") + manager + } + + def size(): Int = { + cache.size() + } + + def clear(): Unit = { + cache.clear() + } + } + + private class BlockResolver(cache: Cache) extends ShuffleBlockResolver { + override def getBlockData(blockId: BlockId, dirs: Option[Array[String]]): ManagedBuffer = { + val shuffleId = blockId match { + case id: ShuffleBlockId => + id.shuffleId + case batchId: ShuffleBlockBatchId => + batchId.shuffleId + case _ => + throw new IllegalArgumentException( + "GlutenShuffleManager: Unsupported shuffle block id: " + blockId) + } + cache.get(shuffleId).shuffleBlockResolver.getBlockData(blockId, dirs) + } + + 
override def getMergedBlockData( + blockId: ShuffleMergedBlockId, + dirs: Option[Array[String]]): Seq[ManagedBuffer] = { + val shuffleId = blockId.shuffleId + cache.get(shuffleId).shuffleBlockResolver.getMergedBlockData(blockId, dirs) + } + + override def getMergedBlockMeta( + blockId: ShuffleMergedBlockId, + dirs: Option[Array[String]]): MergedBlockMeta = { + val shuffleId = blockId.shuffleId + cache.get(shuffleId).shuffleBlockResolver.getMergedBlockMeta(blockId, dirs) + } + + override def stop(): Unit = { + throw new UnsupportedOperationException( + s"BlockResolver ${getClass.getSimpleName} doesn't need to be explicitly stopped") + } + } +} diff --git a/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala new file mode 100644 index 0000000000000..640fc0ab07961 --- /dev/null +++ b/gluten-core/src/test/scala/org/apache/spark/shuffle/GlutenShuffleManagerSuite.scala @@ -0,0 +1,315 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.shuffle + +import org.apache.spark.{Partitioner, ShuffleDependency, SparkConf, TaskContext} +import org.apache.spark.internal.config.SHUFFLE_MANAGER +import org.apache.spark.rdd.EmptyRDD +import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.sql.test.SharedSparkSession + +import java.util.concurrent.atomic.AtomicInteger + +import scala.collection.mutable + +class GlutenShuffleManagerSuite extends SharedSparkSession { + import GlutenShuffleManagerSuite._ + override protected def sparkConf: SparkConf = { + super.sparkConf + .set(SHUFFLE_MANAGER.key, classOf[GlutenShuffleManager].getName) + } + + override protected def afterEach(): Unit = { + val registry = ShuffleManagerRegistry.get() + registry.clear() + counter1.clear() + counter2.clear() + } + + test("register one") { + val registry = ShuffleManagerRegistry.get() + + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager1].getName) + + val gm = spark.sparkContext.env.shuffleManager + assert(counter1.count("stop") == 0) + gm.stop() + assert(counter1.count("stop") == 1) + gm.stop() + gm.stop() + assert(counter1.count("stop") == 3) + } + + test("register two") { + val registry = ShuffleManagerRegistry.get() + + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager1].getName) + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager2].getName) + + val gm = spark.sparkContext.env.shuffleManager + assert(counter1.count("registerShuffle") == 0) + assert(counter2.count("registerShuffle") == 0) + // The statement calls #registerShuffle internally. + val dep = + new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner) + gm.unregisterShuffle(dep.shuffleId) + assert(counter1.count("registerShuffle") == 0) + assert(counter2.count("registerShuffle") == 1) + + assert(counter1.count("stop") == 0) + assert(counter2.count("stop") == 0) + gm.stop() + assert(counter1.count("stop") == 1) + assert(counter2.count("stop") == 1) + } + + test("register two - disordered registration") { + val registry = ShuffleManagerRegistry.get() + + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager1].getName) + + val gm = spark.sparkContext.env.shuffleManager + assert(counter1.count("registerShuffle") == 0) + assert(counter2.count("registerShuffle") == 0) + val dep1 = + new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner) + gm.unregisterShuffle(dep1.shuffleId) + assert(counter1.count("registerShuffle") == 1) + assert(counter2.count("registerShuffle") == 0) + + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager2].getName) + + // The statement calls #registerShuffle internally. 
+ val dep2 = + new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner) + gm.unregisterShuffle(dep2.shuffleId) + assert(counter1.count("registerShuffle") == 1) + assert(counter2.count("registerShuffle") == 1) + + assert(counter1.count("stop") == 0) + assert(counter2.count("stop") == 0) + gm.stop() + assert(counter1.count("stop") == 1) + assert(counter2.count("stop") == 1) + } + + test("register two - with empty key") { + val registry = ShuffleManagerRegistry.get() + + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager1].getName) + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = false + }, + classOf[ShuffleManager2].getName) + + val gm = spark.sparkContext.env.shuffleManager + assert(counter1.count("registerShuffle") == 0) + assert(counter2.count("registerShuffle") == 0) + // The statement calls #registerShuffle internally. + val dep = + new ShuffleDependency(new EmptyRDD[Product2[Any, Any]](spark.sparkContext), DummyPartitioner) + gm.unregisterShuffle(dep.shuffleId) + assert(counter1.count("registerShuffle") == 1) + assert(counter2.count("registerShuffle") == 0) + } + + test("register recursively") { + val registry = ShuffleManagerRegistry.get() + + assertThrows[IllegalArgumentException]( + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[GlutenShuffleManager].getName)) + } + + test("register duplicated") { + val registry = ShuffleManagerRegistry.get() + + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager1].getName) + assertThrows[IllegalArgumentException]( + registry.register( + new LookupKey { + override def accepts[K, V, C](dependency: ShuffleDependency[K, V, C]): Boolean = true + }, + classOf[ShuffleManager1].getName)) + } +} + +object GlutenShuffleManagerSuite { + private val counter1 = new InvocationCounter + private val counter2 = new InvocationCounter + + class ShuffleManager1(conf: SparkConf) extends ShuffleManager { + private val delegate = new SortShuffleManager(conf) + private val counter = counter1 + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + counter.increment("registerShuffle") + delegate.registerShuffle(shuffleId, dependency) + } + + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + counter.increment("getWriter") + delegate.getWriter(handle, mapId, context, metrics) + } + + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + counter.increment("getReader") + delegate.getReader( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + context, + metrics) + } + + override def unregisterShuffle(shuffleId: Int): Boolean = { + counter.increment("unregisterShuffle") + delegate.unregisterShuffle(shuffleId) + } + + override def shuffleBlockResolver: ShuffleBlockResolver = { + counter.increment("shuffleBlockResolver") + delegate.shuffleBlockResolver + } + + override def stop(): Unit = { + 
counter.increment("stop") + delegate.stop() + } + } + + class ShuffleManager2(conf: SparkConf, isDriver: Boolean) extends ShuffleManager { + private val delegate = new SortShuffleManager(conf) + private val counter = counter2 + override def registerShuffle[K, V, C]( + shuffleId: Int, + dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { + counter.increment("registerShuffle") + delegate.registerShuffle(shuffleId, dependency) + } + + override def getWriter[K, V]( + handle: ShuffleHandle, + mapId: Long, + context: TaskContext, + metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = { + counter.increment("getWriter") + delegate.getWriter(handle, mapId, context, metrics) + } + + override def getReader[K, C]( + handle: ShuffleHandle, + startMapIndex: Int, + endMapIndex: Int, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { + counter.increment("getReader") + delegate.getReader( + handle, + startMapIndex, + endMapIndex, + startPartition, + endPartition, + context, + metrics) + } + + override def unregisterShuffle(shuffleId: Int): Boolean = { + counter.increment("unregisterShuffle") + delegate.unregisterShuffle(shuffleId) + } + + override def shuffleBlockResolver: ShuffleBlockResolver = { + counter.increment("shuffleBlockResolver") + delegate.shuffleBlockResolver + } + + override def stop(): Unit = { + counter.increment("stop") + delegate.stop() + } + } + + private class InvocationCounter { + private val counter: mutable.Map[String, AtomicInteger] = mutable.Map() + + def increment(name: String): Unit = synchronized { + counter.getOrElseUpdate(name, new AtomicInteger()).incrementAndGet() + } + + def count(name: String): Int = { + counter.getOrElse(name, new AtomicInteger()).get() + } + + def clear(): Unit = { + counter.clear() + } + } + + private object DummyPartitioner extends Partitioner { + override def numPartitions: Int = 0 + override def getPartition(key: Any): Int = 0 + } +} From 6dd91bac6cc155810beb34415a6dee1a03069f6d Mon Sep 17 00:00:00 2001 From: Hongze Zhang Date: Tue, 3 Dec 2024 09:40:46 +0800 Subject: [PATCH 2/5] [GLUTEN-7745][VL] Incorporate SQL Union operator into Velox execution pipeline (#7842) --- .../backendsapi/clickhouse/CHMetricsApi.scala | 8 + .../apache/gluten/utils/GlutenURLDecoder.java | 2 +- .../backendsapi/velox/VeloxListenerApi.scala | 3 +- .../backendsapi/velox/VeloxMetricsApi.scala | 11 ++ .../backendsapi/velox/VeloxRuleApi.scala | 2 + .../apache/gluten/metrics/MetricsUtil.scala | 9 +- .../gluten/metrics/UnionMetricsUpdater.scala | 35 ++++ .../gluten/execution/MiscOperatorSuite.scala | 28 +++- .../VeloxOrcDataTypeValidationSuite.scala | 5 +- .../VeloxParquetDataTypeValidationSuite.scala | 5 +- cpp/velox/compute/WholeStageResultIterator.cc | 44 ++++- cpp/velox/substrait/SubstraitToVeloxPlan.cc | 46 ++++++ cpp/velox/substrait/SubstraitToVeloxPlan.h | 3 + .../SubstraitToVeloxPlanValidator.cc | 153 ++++++++++++++---- .../substrait/SubstraitToVeloxPlanValidator.h | 15 +- .../gluten/substrait/rel/RelBuilder.java | 17 ++ .../gluten/substrait/rel/SetRelNode.java | 62 +++++++ .../gluten/backendsapi/MetricsApi.scala | 4 + .../BasicPhysicalOperatorTransformer.scala | 3 +- .../execution/UnionExecTransformer.scala | 99 ++++++++++++ .../columnar/UnionTransformerRule.scala | 61 +++++++ .../WholeStageTransformerSuite.scala | 35 +++- .../org/apache/gluten/GlutenConfig.scala | 9 ++ 23 files changed, 604 insertions(+), 55 deletions(-) create mode 100644 
backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala create mode 100644 gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java create mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala create mode 100644 gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala index 73b2d0f21101f..a0576a807b986 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHMetricsApi.scala @@ -450,6 +450,14 @@ class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { s"SampleTransformer metrics update is not supported in CH backend") } + override def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = + throw new UnsupportedOperationException( + "UnionExecTransformer metrics update is not supported in CH backend") + + override def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = + throw new UnsupportedOperationException( + "UnionExecTransformer metrics update is not supported in CH backend") + def genWriteFilesTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map( "physicalWrittenBytes" -> SQLMetrics.createMetric(sparkContext, "number of written bytes"), diff --git a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java index 9228a2f860ae7..856ddf1597305 100644 --- a/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java +++ b/backends-velox/src/main/java/org/apache/gluten/utils/GlutenURLDecoder.java @@ -31,7 +31,7 @@ public class GlutenURLDecoder { *
Note: The World Wide Web Consortium * Recommendation states that UTF-8 should be used. Not doing so may introduce - * incompatibilites. + * incompatibilities. * * @param s the String to decode * @param enc The name of a supported character diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala index d29d3029709e8..3a82abe618330 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxListenerApi.scala @@ -31,6 +31,7 @@ import org.apache.spark.{HdfsConfGenerator, SparkConf, SparkContext} import org.apache.spark.api.plugin.PluginContext import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.sql.execution.ColumnarCachedBatchSerializer import org.apache.spark.sql.execution.datasources.GlutenWriterColumnarRules import org.apache.spark.sql.execution.datasources.velox.{VeloxParquetWriterInjects, VeloxRowSplitter} import org.apache.spark.sql.expression.UDFResolver @@ -75,7 +76,7 @@ class VeloxListenerApi extends ListenerApi with Logging { if (conf.getBoolean(GlutenConfig.COLUMNAR_TABLE_CACHE_ENABLED.key, defaultValue = false)) { conf.set( StaticSQLConf.SPARK_CACHE_SERIALIZER.key, - "org.apache.spark.sql.execution.ColumnarCachedBatchSerializer") + classOf[ColumnarCachedBatchSerializer].getName) } // Static initializers for driver. diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala index e70e1d13bdfe5..934b680382eaf 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxMetricsApi.scala @@ -582,4 +582,15 @@ class VeloxMetricsApi extends MetricsApi with Logging { override def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = new SampleMetricsUpdater(metrics) + + override def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map( + "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input rows"), + "inputVectors" -> SQLMetrics.createMetric(sparkContext, "number of input vectors"), + "inputBytes" -> SQLMetrics.createSizeMetric(sparkContext, "number of input bytes"), + "wallNanos" -> SQLMetrics.createNanoTimingMetric(sparkContext, "time of union"), + "cpuCount" -> SQLMetrics.createMetric(sparkContext, "cpu wall time count") + ) + + override def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater = + new UnionMetricsUpdater(metrics) } diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala index 7841e6cd94b1b..7337be5737106 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala @@ -92,6 +92,7 @@ object VeloxRuleApi { c => HeuristicTransform.Single(validatorBuilder(c.glutenConf), rewrites, offloads)) // Legacy: Post-transform rules. 
+ injector.injectPostTransform(_ => UnionTransformerRule()) injector.injectPostTransform(c => PartialProjectRule.apply(c.session)) injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectPostTransform(c => RewriteTransformer.apply(c.session)) @@ -178,6 +179,7 @@ object VeloxRuleApi { // Gluten RAS: Post rules. injector.injectPostTransform(_ => RemoveTransitions) + injector.injectPostTransform(_ => UnionTransformerRule()) injector.injectPostTransform(c => PartialProjectRule.apply(c.session)) injector.injectPostTransform(_ => RemoveNativeWriteFilesSortAndProject()) injector.injectPostTransform(c => RewriteTransformer.apply(c.session)) diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala index cd50d0b8e20c6..b8ef1620f905f 100644 --- a/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala +++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/MetricsUtil.scala @@ -58,7 +58,8 @@ object MetricsUtil extends Logging { assert(t.children.size == 1, "MetricsUpdater.None can only be used on unary operator") treeifyMetricsUpdaters(t.children.head) case t: TransformSupport => - MetricsUpdaterTree(t.metricsUpdater(), t.children.map(treeifyMetricsUpdaters)) + // Reversed children order to match the traversal code. + MetricsUpdaterTree(t.metricsUpdater(), t.children.reverse.map(treeifyMetricsUpdaters)) case _ => MetricsUpdaterTree(MetricsUpdater.Terminate, Seq()) } @@ -233,6 +234,12 @@ object MetricsUtil extends Logging { operatorMetrics, metrics.getSingleMetrics, joinParamsMap.get(operatorIdx)) + case u: UnionMetricsUpdater => + // JoinRel outputs two suites of metrics respectively for hash build and hash probe. + // Therefore, fetch one more suite of metrics here. + operatorMetrics.add(metrics.getOperatorMetrics(curMetricsIdx)) + curMetricsIdx -= 1 + u.updateUnionMetrics(operatorMetrics) case hau: HashAggregateMetricsUpdater => hau.updateAggregationMetrics(operatorMetrics, aggParamsMap.get(operatorIdx)) case lu: LimitMetricsUpdater => diff --git a/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala new file mode 100644 index 0000000000000..9e91cf368c0a3 --- /dev/null +++ b/backends-velox/src/main/scala/org/apache/gluten/metrics/UnionMetricsUpdater.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.metrics + +import org.apache.spark.sql.execution.metric.SQLMetric + +class UnionMetricsUpdater(val metrics: Map[String, SQLMetric]) extends MetricsUpdater { + override def updateNativeMetrics(opMetrics: IOperatorMetrics): Unit = { + throw new UnsupportedOperationException() + } + + def updateUnionMetrics(unionMetrics: java.util.ArrayList[OperatorMetrics]): Unit = { + // Union was interpreted to LocalExchange + LocalPartition. Use metrics from LocalExchange. + val localExchangeMetrics = unionMetrics.get(0) + metrics("numInputRows") += localExchangeMetrics.inputRows + metrics("inputVectors") += localExchangeMetrics.inputVectors + metrics("inputBytes") += localExchangeMetrics.inputBytes + metrics("cpuCount") += localExchangeMetrics.cpuCount + metrics("wallNanos") += localExchangeMetrics.wallNanos + } +} diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala index 5cb2b652604d3..8063a5d12207a 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/MiscOperatorSuite.scala @@ -537,11 +537,37 @@ class MiscOperatorSuite extends VeloxWholeStageTransformerSuite with AdaptiveSpa |""".stripMargin) { df => { - getExecutedPlan(df).exists(plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined) + assert( + getExecutedPlan(df).exists( + plan => plan.find(_.isInstanceOf[ColumnarUnionExec]).isDefined)) } } } + test("union_all two tables with known partitioning") { + withSQLConf(GlutenConfig.NATIVE_UNION_ENABLED.key -> "true") { + compareDfResultsAgainstVanillaSpark( + () => { + val df1 = spark.sql("select l_orderkey as orderkey from lineitem") + val df2 = spark.sql("select o_orderkey as orderkey from orders") + df1.repartition(5).union(df2.repartition(5)) + }, + compareResult = true, + checkGlutenOperatorMatch[UnionExecTransformer] + ) + + compareDfResultsAgainstVanillaSpark( + () => { + val df1 = spark.sql("select l_orderkey as orderkey from lineitem") + val df2 = spark.sql("select o_orderkey as orderkey from orders") + df1.repartition(5).union(df2.repartition(6)) + }, + compareResult = true, + checkGlutenOperatorMatch[ColumnarUnionExec] + ) + } + } + test("union_all three tables") { runQueryAndCompare(""" |select count(orderkey) from ( diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala index 24e04f2dfce3e..6ac59ba4fa6bf 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxOrcDataTypeValidationSuite.scala @@ -255,7 +255,10 @@ class VeloxOrcDataTypeValidationSuite extends VeloxWholeStageTransformerSuite { |""".stripMargin) { df => { - assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[ColumnarUnionExec])) + assert( + getExecutedPlan(df).exists( + plan => + plan.isInstanceOf[ColumnarUnionExec] || plan.isInstanceOf[UnionExecTransformer])) } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala index 57ca448fec79d..cb5614f39669c 100644 --- 
a/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/VeloxParquetDataTypeValidationSuite.scala @@ -254,7 +254,10 @@ class VeloxParquetDataTypeValidationSuite extends VeloxWholeStageTransformerSuit |""".stripMargin) { df => { - assert(getExecutedPlan(df).exists(plan => plan.isInstanceOf[ColumnarUnionExec])) + assert( + getExecutedPlan(df).exists( + plan => + plan.isInstanceOf[ColumnarUnionExec] || plan.isInstanceOf[UnionExecTransformer])) } } diff --git a/cpp/velox/compute/WholeStageResultIterator.cc b/cpp/velox/compute/WholeStageResultIterator.cc index b6ecbd959f098..411c6c5636462 100644 --- a/cpp/velox/compute/WholeStageResultIterator.cc +++ b/cpp/velox/compute/WholeStageResultIterator.cc @@ -91,7 +91,7 @@ WholeStageResultIterator::WholeStageResultIterator( std::move(queryCtx), velox::exec::Task::ExecutionMode::kSerial); if (!task_->supportSerialExecutionMode()) { - throw std::runtime_error("Task doesn't support single thread execution: " + planNode->toString()); + throw std::runtime_error("Task doesn't support single threaded execution: " + planNode->toString()); } auto fileSystem = velox::filesystems::getFileSystem(spillDir, nullptr); GLUTEN_CHECK(fileSystem != nullptr, "File System for spilling is null!"); @@ -248,15 +248,47 @@ void WholeStageResultIterator::getOrderedNodeIds( const std::shared_ptr& planNode, std::vector& nodeIds) { bool isProjectNode = (std::dynamic_pointer_cast(planNode) != nullptr); + bool isLocalExchangeNode = (std::dynamic_pointer_cast(planNode) != nullptr); + bool isUnionNode = isLocalExchangeNode && + std::dynamic_pointer_cast(planNode)->type() == + velox::core::LocalPartitionNode::Type::kGather; const auto& sourceNodes = planNode->sources(); - for (const auto& sourceNode : sourceNodes) { + if (isProjectNode) { + GLUTEN_CHECK(sourceNodes.size() == 1, "Illegal state"); + const auto sourceNode = sourceNodes.at(0); // Filter over Project are mapped into FilterProject operator in Velox. // Metrics are all applied on Project node, and the metrics for Filter node // do not exist. - if (isProjectNode && std::dynamic_pointer_cast(sourceNode)) { + if (std::dynamic_pointer_cast(sourceNode)) { omittedNodeIds_.insert(sourceNode->id()); } getOrderedNodeIds(sourceNode, nodeIds); + nodeIds.emplace_back(planNode->id()); + return; + } + + if (isUnionNode) { + // FIXME: The whole metrics system in gluten-substrait is magic. Passing metrics trees through JNI with a trivial + // array is possible but requires for a solid design. Apparently we haven't had it. All the code requires complete + // rework. + // Union was interpreted as LocalPartition + LocalExchange + 2 fake projects as children in Velox. So we only fetch + // metrics from the root node. + std::vector> unionChildren{}; + for (const auto& source : planNode->sources()) { + const auto projectedChild = std::dynamic_pointer_cast(source); + GLUTEN_CHECK(projectedChild != nullptr, "Illegal state"); + const auto projectSources = projectedChild->sources(); + GLUTEN_CHECK(projectSources.size() == 1, "Illegal state"); + const auto projectSource = projectSources.at(0); + getOrderedNodeIds(projectSource, nodeIds); + } + nodeIds.emplace_back(planNode->id()); + return; + } + + for (const auto& sourceNode : sourceNodes) { + // Post-order traversal. 
+ getOrderedNodeIds(sourceNode, nodeIds); } nodeIds.emplace_back(planNode->id()); } @@ -350,9 +382,9 @@ void WholeStageResultIterator::collectMetrics() { continue; } - const auto& status = planStats.at(nodeId); - // Add each operator status into metrics. - for (const auto& entry : status.operatorStats) { + const auto& stats = planStats.at(nodeId); + // Add each operator stats into metrics. + for (const auto& entry : stats.operatorStats) { const auto& second = entry.second; metrics_->get(Metrics::kInputRows)[metricIndex] = second->inputRows; metrics_->get(Metrics::kInputVectors)[metricIndex] = second->inputVectors; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 1efa7338796dc..3ceccca4a3de2 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -1043,6 +1043,50 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan( childNode); } +core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SetRel& setRel) { + switch (setRel.op()) { + case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: { + std::vector children; + for (int32_t i = 0; i < setRel.inputs_size(); ++i) { + const auto& input = setRel.inputs(i); + children.push_back(toVeloxPlan(input)); + } + GLUTEN_CHECK(!children.empty(), "At least one source is required for Velox LocalPartition"); + + // Velox doesn't allow different field names in schemas of LocalPartitionNode's children. + // Add project nodes to unify the schemas. + const RowTypePtr outRowType = asRowType(children[0]->outputType()); + std::vector outNames; + for (int32_t colIdx = 0; colIdx < outRowType->size(); ++colIdx) { + const auto name = outRowType->childAt(colIdx)->name(); + outNames.push_back(name); + } + + std::vector projectedChildren; + for (int32_t i = 0; i < children.size(); ++i) { + const auto& child = children[i]; + const RowTypePtr& childRowType = child->outputType(); + std::vector expressions; + for (int32_t colIdx = 0; colIdx < outNames.size(); ++colIdx) { + const auto fa = + std::make_shared(childRowType->childAt(colIdx), childRowType->nameOf(colIdx)); + const auto cast = std::make_shared(outRowType->childAt(colIdx), fa, false); + expressions.push_back(cast); + } + auto project = std::make_shared(nextPlanNodeId(), outNames, expressions, child); + projectedChildren.push_back(project); + } + return std::make_shared( + nextPlanNodeId(), + core::LocalPartitionNode::Type::kGather, + std::make_shared(), + projectedChildren); + } + default: + throw GlutenException("Unsupported SetRel op: " + std::to_string(setRel.op())); + } +} + core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait::SortRel& sortRel) { auto childNode = convertSingleInput<::substrait::SortRel>(sortRel); auto [sortingKeys, sortingOrders] = processSortField(sortRel.sorts(), childNode->outputType()); @@ -1298,6 +1342,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: return toVeloxPlan(rel.write()); } else if (rel.has_windowgrouplimit()) { return toVeloxPlan(rel.windowgrouplimit()); + } else if (rel.has_set()) { + return toVeloxPlan(rel.set()); } else { VELOX_NYI("Substrait conversion not supported for Rel."); } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.h b/cpp/velox/substrait/SubstraitToVeloxPlan.h index 51e50ce34767f..6121923df7871 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.h @@ -84,6 +84,9 @@ class 
SubstraitToVeloxPlanConverter { /// Used to convert Substrait WindowGroupLimitRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel); + /// Used to convert Substrait SetRel into Velox PlanNode. + core::PlanNodePtr toVeloxPlan(const ::substrait::SetRel& setRel); + /// Used to convert Substrait JoinRel into Velox PlanNode. core::PlanNodePtr toVeloxPlan(const ::substrait::JoinRel& joinRel); diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index 3b74caf8ba5a9..9325fed3217c1 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -22,7 +22,6 @@ #include "TypeUtils.h" #include "udf/UdfLoader.h" #include "utils/Common.h" -#include "velox/core/ExpressionEvaluator.h" #include "velox/exec/Aggregate.h" #include "velox/expression/Expr.h" #include "velox/expression/SignatureBinder.h" @@ -30,7 +29,7 @@ namespace gluten { namespace { -static const char* extractFileName(const char* file) { +const char* extractFileName(const char* file) { return strrchr(file, '/') ? strrchr(file, '/') + 1 : file; } @@ -53,13 +52,13 @@ static const char* extractFileName(const char* file) { __FUNCTION__, \ reason)) -static const std::unordered_set kRegexFunctions = { +const std::unordered_set kRegexFunctions = { "regexp_extract", "regexp_extract_all", "regexp_replace", "rlike"}; -static const std::unordered_set kBlackList = { +const std::unordered_set kBlackList = { "split_part", "factorial", "concat_ws", @@ -70,32 +69,59 @@ static const std::unordered_set kBlackList = { "approx_percentile", "get_array_struct_fields", "map_from_arrays"}; - } // namespace -bool SubstraitToVeloxPlanValidator::validateInputTypes( +bool SubstraitToVeloxPlanValidator::parseVeloxType( const ::substrait::extensions::AdvancedExtension& extension, - std::vector& types) { + TypePtr& out) { + ::substrait::Type substraitType; // The input type is wrapped in enhancement. if (!extension.has_enhancement()) { LOG_VALIDATION_MSG("Input type is not wrapped in enhancement."); return false; } const auto& enhancement = extension.enhancement(); - ::substrait::Type inputType; - if (!enhancement.UnpackTo(&inputType)) { + if (!enhancement.UnpackTo(&substraitType)) { LOG_VALIDATION_MSG("Enhancement can't be unpacked to inputType."); return false; } - if (!inputType.has_struct_()) { - LOG_VALIDATION_MSG("Input type has no struct."); + + out = SubstraitParser::parseType(substraitType); + return true; +} + +bool SubstraitToVeloxPlanValidator::flattenVeloxType1(const TypePtr& type, std::vector& out) { + if (type->kind() != TypeKind::ROW) { + LOG_VALIDATION_MSG("Type is not a RowType."); + return false; + } + auto rowType = std::dynamic_pointer_cast(type); + if (!rowType) { + LOG_VALIDATION_MSG("Failed to cast to RowType."); return false; } + for (const auto& field : rowType->children()) { + out.emplace_back(field); + } + return true; +} - // Get the input types. 
- const auto& sTypes = inputType.struct_().types(); - for (const auto& sType : sTypes) { - types.emplace_back(SubstraitParser::parseType(sType)); +bool SubstraitToVeloxPlanValidator::flattenVeloxType2(const TypePtr& type, std::vector>& out) { + if (type->kind() != TypeKind::ROW) { + LOG_VALIDATION_MSG("Type is not a RowType."); + return false; + } + auto rowType = std::dynamic_pointer_cast(type); + if (!rowType) { + LOG_VALIDATION_MSG("Failed to cast to RowType."); + return false; + } + for (const auto& field : rowType->children()) { + std::vector inner; + if (!flattenVeloxType1(field, inner)) { + return false; + } + out.emplace_back(inner); } return true; } @@ -341,10 +367,11 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeR } // Validate input data type. + TypePtr inputRowType; std::vector types; if (writeRel.has_named_table()) { const auto& extension = writeRel.named_table().advanced_extension(); - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input type validation in WriteRel."); return false; } @@ -380,12 +407,12 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WriteRel& writeR } bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchRel) { - RowTypePtr rowType = nullptr; // Get and validate the input types from extension. if (fetchRel.has_advanced_extension()) { const auto& extension = fetchRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in FetchRel."); return false; } @@ -396,7 +423,6 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FetchRel& fetchR for (auto colIdx = 0; colIdx < types.size(); colIdx++) { names.emplace_back(SubstraitParser::makeNodeName(inputPlanNodeId, colIdx)); } - rowType = std::make_shared(std::move(names), std::move(types)); } if (fetchRel.offset() < 0 || fetchRel.count() < 0) { @@ -412,8 +438,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::TopNRel& topNRel // Get and validate the input types from extension. if (topNRel.has_advanced_extension()) { const auto& extension = topNRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in TopNRel."); return false; } @@ -457,8 +484,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::GenerateRel& gen return false; } const auto& extension = generateRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in GenerateRel."); return false; } @@ -487,8 +515,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ExpandRel& expan // Get and validate the input types from extension. 
if (expandRel.has_advanced_extension()) { const auto& extension = expandRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Unsupported input types in ExpandRel."); return false; } @@ -571,8 +600,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo return false; } const auto& extension = windowRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in WindowRel."); return false; } @@ -680,7 +710,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo LOG_VALIDATION_MSG("in windowRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet1({std::move(expression)}, execCtx_); } } @@ -699,8 +729,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit return false; } const auto& extension = windowGroupLimitRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in WindowGroupLimitRel."); return false; } @@ -750,13 +781,61 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowGroupLimit LOG_VALIDATION_MSG("in windowGroupLimitRel, the sorting key in Sort Operator only support field."); return false; } - exec::ExprSet exprSet({std::move(expression)}, execCtx_); + exec::ExprSet exprSet1({std::move(expression)}, execCtx_); } } return true; } +bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SetRel& setRel) { + switch (setRel.op()) { + case ::substrait::SetRel_SetOp::SetRel_SetOp_SET_OP_UNION_ALL: { + for (int32_t i = 0; i < setRel.inputs_size(); ++i) { + const auto& input = setRel.inputs(i); + if (!validate(input)) { + LOG_VALIDATION_MSG("ProjectRel input"); + return false; + } + } + if (!setRel.has_advanced_extension()) { + LOG_VALIDATION_MSG("Input types are expected in SetRel."); + return false; + } + const auto& extension = setRel.advanced_extension(); + TypePtr inputRowType; + std::vector> childrenTypes; + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType2(inputRowType, childrenTypes)) { + LOG_VALIDATION_MSG("Validation failed for input types in SetRel."); + return false; + } + std::vector childrenRowTypes; + for (auto i = 0; i < childrenTypes.size(); ++i) { + auto& types = childrenTypes.at(i); + std::vector names; + names.reserve(types.size()); + for (auto colIdx = 0; colIdx < types.size(); colIdx++) { + names.emplace_back(SubstraitParser::makeNodeName(i, colIdx)); + } + childrenRowTypes.push_back(std::make_shared(std::move(names), std::move(types))); + } + + for (auto i = 1; i < childrenRowTypes.size(); ++i) { + if (!(childrenRowTypes[i]->equivalent(*childrenRowTypes[0]))) { + LOG_VALIDATION_MSG( + "All sources of the Set operation must have the same output type: " + childrenRowTypes[i]->toString() + + " vs. 
" + childrenRowTypes[0]->toString()); + return false; + } + } + return true; + } + default: + LOG_VALIDATION_MSG("Unsupported SetRel op: " + std::to_string(setRel.op())); + return false; + } +} + bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel) { if (sortRel.has_input() && !validate(sortRel.input())) { return false; @@ -769,8 +848,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::SortRel& sortRel } const auto& extension = sortRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in SortRel."); return false; } @@ -822,8 +902,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ProjectRel& proj return false; } const auto& extension = projectRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in ProjectRel."); return false; } @@ -865,8 +946,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::FilterRel& filte return false; } const auto& extension = filterRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in FilterRel."); return false; } @@ -938,8 +1020,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::JoinRel& joinRel } const auto& extension = joinRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { LOG_VALIDATION_MSG("Validation failed for input types in JoinRel."); return false; } @@ -991,8 +1074,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::CrossRel& crossR } const auto& extension = crossRel.advanced_extension(); + TypePtr inputRowType; std::vector types; - if (!validateInputTypes(extension, types)) { + if (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types)) { logValidateMsg("Native validation failed due to: Validation failed for input types in CrossRel"); return false; } @@ -1070,11 +1154,13 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag // Validate input types. if (aggRel.has_advanced_extension()) { + TypePtr inputRowType; std::vector types; const auto& extension = aggRel.advanced_extension(); // Aggregate always has advanced extension for streaming aggregate optimization, // but only some of them have enhancement for validation. 
- if (extension.has_enhancement() && !validateInputTypes(extension, types)) { + if (extension.has_enhancement() && + (!parseVeloxType(extension, inputRowType) || !flattenVeloxType1(inputRowType, types))) { LOG_VALIDATION_MSG("Validation failed for input types in AggregateRel."); return false; } @@ -1266,7 +1352,10 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::Rel& rel) { return validate(rel.write()); } else if (rel.has_windowgrouplimit()) { return validate(rel.windowgrouplimit()); + } else if (rel.has_set()) { + return validate(rel.set()); } else { + LOG_VALIDATION_MSG("Unsupported relation type: " + rel.GetTypeName()); return false; } } diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h index 1fe174928fd96..0c8d882ca0314 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.h @@ -61,6 +61,9 @@ class SubstraitToVeloxPlanValidator { /// Used to validate whether the computing of this WindowGroupLimit is supported. bool validate(const ::substrait::WindowGroupLimitRel& windowGroupLimitRel); + /// Used to validate whether the computing of this Set is supported. + bool validate(const ::substrait::SetRel& setRel); + /// Used to validate whether the computing of this Aggregation is supported. bool validate(const ::substrait::AggregateRel& aggRel); @@ -103,9 +106,17 @@ class SubstraitToVeloxPlanValidator { std::vector validateLog_; - /// Used to get types from advanced extension and validate them. - bool validateInputTypes(const ::substrait::extensions::AdvancedExtension& extension, std::vector& types); + /// Used to get types from advanced extension and validate them, then convert to a Velox type that has arbitrary + /// levels of nesting. + bool parseVeloxType(const ::substrait::extensions::AdvancedExtension& extension, TypePtr& out); + + /// Flattens a Velox type with single level of nesting into a std::vector of child types. + bool flattenVeloxType1(const TypePtr& type, std::vector& out); + + /// Flattens a Velox type with two level of nesting into a dual-nested std::vector of child types. + bool flattenVeloxType2(const TypePtr& type, std::vector>& out); + /// Validate aggregate rel. bool validateAggRelFunctionType(const ::substrait::AggregateRel& substraitAgg); /// Validate the round scalar function. 
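Note on the new type helpers declared above: splitting validateInputTypes into parseVeloxType plus flattenVeloxType1/flattenVeloxType2 is what lets SetRel validation carry one nested struct per child input. The sketch below illustrates the flatten contract only; it is not the committed implementation, the *Sketch function names are hypothetical, and it assumes the Velox RowType accessors isRow(), asRow(), size() and childAt().

// Illustrative sketch of the flatten helpers' contract (not the committed code).
#include <vector>
#include "velox/type/Type.h"

using facebook::velox::TypePtr;

// ROW(a, b, c) -> {a, b, c}; rejects any non-ROW input.
bool flattenVeloxType1Sketch(const TypePtr& type, std::vector<TypePtr>& out) {
  if (type == nullptr || !type->isRow())
    return false;
  const auto& row = type->asRow();
  for (size_t i = 0; i < row.size(); ++i)
    out.emplace_back(row.childAt(i));
  return true;
}

// ROW(ROW(a, b), ROW(a, b)) -> {{a, b}, {a, b}}; one inner vector per child input,
// matching how the Substrait SetRel enhancement encodes N-way input schemas.
bool flattenVeloxType2Sketch(const TypePtr& type, std::vector<std::vector<TypePtr>>& out) {
  if (type == nullptr || !type->isRow())
    return false;
  const auto& row = type->asRow();
  for (size_t i = 0; i < row.size(); ++i) {
    std::vector<TypePtr> children;
    if (!flattenVeloxType1Sketch(row.childAt(i), children))
      return false;
    out.emplace_back(std::move(children));
  }
  return true;
}

Packing the per-child schemas as a two-level ROW keeps the advanced-extension enhancement a single Substrait type message while still letting the validator rebuild each child's row type and compare it against the first child's, as the SetRel validation above does.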
diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java index def1dca0a0286..7d19311808470 100644 --- a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/RelBuilder.java @@ -27,6 +27,7 @@ import io.substrait.proto.CrossRel; import io.substrait.proto.JoinRel; +import io.substrait.proto.SetRel; import io.substrait.proto.SortField; import org.apache.spark.sql.catalyst.expressions.Attribute; @@ -317,4 +318,20 @@ public static RelNode makeGenerateRel( context.registerRelToOperator(operatorId); return new GenerateRelNode(input, generator, childOutput, extensionNode, outer); } + + public static RelNode makeSetRel( + List inputs, SetRel.SetOp setOp, SubstraitContext context, Long operatorId) { + context.registerRelToOperator(operatorId); + return new SetRelNode(inputs, setOp); + } + + public static RelNode makeSetRel( + List inputs, + SetRel.SetOp setOp, + AdvancedExtensionNode extensionNode, + SubstraitContext context, + Long operatorId) { + context.registerRelToOperator(operatorId); + return new SetRelNode(inputs, setOp, extensionNode); + } } diff --git a/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java new file mode 100644 index 0000000000000..ddcfb1701dd10 --- /dev/null +++ b/gluten-substrait/src/main/java/org/apache/gluten/substrait/rel/SetRelNode.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.substrait.rel; + +import org.apache.gluten.substrait.extensions.AdvancedExtensionNode; + +import io.substrait.proto.Rel; +import io.substrait.proto.RelCommon; +import io.substrait.proto.SetRel; + +import java.io.Serializable; +import java.util.List; + +public class SetRelNode implements RelNode, Serializable { + private final List inputs; + private final SetRel.SetOp setOp; + private final AdvancedExtensionNode extensionNode; + + public SetRelNode(List inputs, SetRel.SetOp setOp, AdvancedExtensionNode extensionNode) { + this.inputs = inputs; + this.setOp = setOp; + this.extensionNode = extensionNode; + } + + public SetRelNode(List inputs, SetRel.SetOp setOp) { + this(inputs, setOp, null); + } + + @Override + public Rel toProtobuf() { + final RelCommon.Builder relCommonBuilder = RelCommon.newBuilder(); + relCommonBuilder.setDirect(RelCommon.Direct.newBuilder()); + final SetRel.Builder setBuilder = SetRel.newBuilder(); + setBuilder.setCommon(relCommonBuilder.build()); + if (inputs != null) { + for (RelNode input : inputs) { + setBuilder.addInputs(input.toProtobuf()); + } + } + setBuilder.setOp(setOp); + if (extensionNode != null) { + setBuilder.setAdvancedExtension(extensionNode.toProtobuf()); + } + final Rel.Builder builder = Rel.newBuilder(); + builder.setSet(setBuilder.build()); + return builder.build(); + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala index c67d4b5f8876b..453cfab4e487e 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/backendsapi/MetricsApi.scala @@ -126,6 +126,10 @@ trait MetricsApi extends Serializable { def genSampleTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater + def genUnionTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] + + def genUnionTransformerMetricsUpdater(metrics: Map[String, SQLMetric]): MetricsUpdater + def genColumnarInMemoryTableMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala index f9755605cab23..ac8e610956dcd 100644 --- a/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/BasicPhysicalOperatorTransformer.scala @@ -261,10 +261,11 @@ abstract class ProjectExecTransformerBase(val list: Seq[NamedExpression], val in } } -// An alternatives for UnionExec. +// An alternative for UnionExec. case class ColumnarUnionExec(children: Seq[SparkPlan]) extends ValidatablePlan { children.foreach { case w: WholeStageTransformer => + // FIXME: Avoid such practice for plan immutability. 
w.setOutputSchemaForPlan(output) case _ => } diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala new file mode 100644 index 0000000000000..d27558746a403 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/execution/UnionExecTransformer.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.execution + +import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.expression.ConverterUtils +import org.apache.gluten.extension.ValidationResult +import org.apache.gluten.metrics.MetricsUpdater +import org.apache.gluten.substrait.`type`.TypeBuilder +import org.apache.gluten.substrait.SubstraitContext +import org.apache.gluten.substrait.extensions.ExtensionBuilder +import org.apache.gluten.substrait.rel.{RelBuilder, RelNode} + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.{SparkPlan, UnionExec} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.vectorized.ColumnarBatch + +import io.substrait.proto.SetRel.SetOp + +import scala.collection.JavaConverters._ + +/** Transformer for UnionExec. Note: Spark's UnionExec represents a SQL UNION ALL. */ +case class UnionExecTransformer(children: Seq[SparkPlan]) extends TransformSupport { + private val union = UnionExec(children) + + // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks. 
+ @transient override lazy val metrics: Map[String, SQLMetric] = + BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetrics(sparkContext) + + override def output: Seq[Attribute] = union.output + + override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = children.flatMap(getColumnarInputRDDs) + + override def metricsUpdater(): MetricsUpdater = + BackendsApiManager.getMetricsApiInstance.genUnionTransformerMetricsUpdater(metrics) + + override protected def withNewChildrenInternal(newChildren: IndexedSeq[SparkPlan]): SparkPlan = + copy(children = newChildren) + + override protected def doValidateInternal(): ValidationResult = { + val context = new SubstraitContext + val operatorId = context.nextOperatorId(this.nodeName) + val relNode = getRelNode(context, operatorId, children.map(_.output), null, true) + doNativeValidation(context, relNode) + } + + override protected def doTransform(context: SubstraitContext): TransformContext = { + val childrenCtx = children.map(_.asInstanceOf[TransformSupport].transform(context)) + val operatorId = context.nextOperatorId(this.nodeName) + val relNode = + getRelNode(context, operatorId, children.map(_.output), childrenCtx.map(_.root), false) + TransformContext(output, relNode) + } + + private def getRelNode( + context: SubstraitContext, + operatorId: Long, + inputAttributes: Seq[Seq[Attribute]], + inputs: Seq[RelNode], + validation: Boolean): RelNode = { + if (validation) { + // Use the second level of nesting to represent N way inputs. + val inputTypeNodes = + inputAttributes.map( + attributes => + attributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava) + val extensionNode = ExtensionBuilder.makeAdvancedExtension( + BackendsApiManager.getTransformerApiInstance.packPBMessage( + TypeBuilder + .makeStruct( + false, + inputTypeNodes.map(nodes => TypeBuilder.makeStruct(false, nodes)).asJava) + .toProtobuf)) + return RelBuilder.makeSetRel( + inputs.asJava, + SetOp.SET_OP_UNION_ALL, + extensionNode, + context, + operatorId) + } + RelBuilder.makeSetRel(inputs.asJava, SetOp.SET_OP_UNION_ALL, context, operatorId) + } +} diff --git a/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala new file mode 100644 index 0000000000000..f0eea08018dd4 --- /dev/null +++ b/gluten-substrait/src/main/scala/org/apache/gluten/extension/columnar/UnionTransformerRule.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.gluten.extension.columnar + +import org.apache.gluten.GlutenConfig +import org.apache.gluten.execution.{ColumnarUnionExec, UnionExecTransformer} + +import org.apache.spark.sql.catalyst.plans.physical.UnknownPartitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.SparkPlan + +/** + * Replace ColumnarUnionExec with UnionExecTransformer if possible. + * + * The rule is not included in [[org.apache.gluten.extension.columnar.heuristic.HeuristicTransform]] + * or [[org.apache.gluten.extension.columnar.enumerated.EnumeratedTransform]] because it relies on + * children's output partitioning to be fully provided. + */ +case class UnionTransformerRule() extends Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = { + if (!GlutenConfig.getConf.enableNativeUnion) { + return plan + } + plan.transformUp { + case plan: ColumnarUnionExec => + val transformer = UnionExecTransformer(plan.children) + if (sameNumPartitions(plan.children) && validate(transformer)) { + transformer + } else { + plan + } + } + } + + private def sameNumPartitions(plans: Seq[SparkPlan]): Boolean = { + val partitioning = plans.map(_.outputPartitioning) + if (partitioning.exists(p => p.isInstanceOf[UnknownPartitioning])) { + return false + } + val numPartitions = plans.map(_.outputPartitioning.numPartitions) + numPartitions.forall(_ == numPartitions.head) + } + + private def validate(union: UnionExecTransformer): Boolean = { + union.doValidate().ok() + } +} diff --git a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala index fd250834d078c..08081fadb5f9a 100644 --- a/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala +++ b/gluten-substrait/src/test/scala/org/apache/gluten/execution/WholeStageTransformerSuite.scala @@ -176,25 +176,39 @@ abstract class WholeStageTransformerSuite result } + protected def compareResultsAgainstVanillaSpark( + sql: String, + compareResult: Boolean = true, + customCheck: DataFrame => Unit, + noFallBack: Boolean = true, + cache: Boolean = false): DataFrame = { + compareDfResultsAgainstVanillaSpark( + () => spark.sql(sql), + compareResult, + customCheck, + noFallBack, + cache) + } + /** * run a query with native engine as well as vanilla spark then compare the result set for * correctness check */ - protected def compareResultsAgainstVanillaSpark( - sqlStr: String, + protected def compareDfResultsAgainstVanillaSpark( + dataframe: () => DataFrame, compareResult: Boolean = true, customCheck: DataFrame => Unit, noFallBack: Boolean = true, cache: Boolean = false): DataFrame = { var expected: Seq[Row] = null withSQLConf(vanillaSparkConfs(): _*) { - val df = spark.sql(sqlStr) + val df = dataframe() expected = df.collect() } - // By default we will fallabck complex type scan but here we should allow + // By default, we will fallback complex type scan but here we should allow // to test support of complex type spark.conf.set("spark.gluten.sql.complexType.scan.fallback.enabled", "false"); - val df = spark.sql(sqlStr) + val df = dataframe() if (cache) { df.cache() } @@ -239,7 +253,12 @@ abstract class WholeStageTransformerSuite noFallBack: Boolean = true, cache: Boolean = false)(customCheck: DataFrame => Unit): DataFrame = { - compareResultsAgainstVanillaSpark(sqlStr, compareResult, customCheck, noFallBack, cache) + compareDfResultsAgainstVanillaSpark( + () 
=> spark.sql(sqlStr), + compareResult, + customCheck, + noFallBack, + cache) } /** @@ -256,8 +275,8 @@ abstract class WholeStageTransformerSuite customCheck: DataFrame => Unit, noFallBack: Boolean = true, compareResult: Boolean = true): Unit = - compareResultsAgainstVanillaSpark( - tpchSQL(queryNum, tpchQueries), + compareDfResultsAgainstVanillaSpark( + () => spark.sql(tpchSQL(queryNum, tpchQueries)), compareResult, customCheck, noFallBack) diff --git a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala index f643ad7eed0ee..9ae4c0ce90eb4 100644 --- a/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala +++ b/shims/common/src/main/scala/org/apache/gluten/GlutenConfig.scala @@ -86,6 +86,8 @@ class GlutenConfig(conf: SQLConf) extends Logging { def enableColumnarUnion: Boolean = conf.getConf(COLUMNAR_UNION_ENABLED) + def enableNativeUnion: Boolean = conf.getConf(NATIVE_UNION_ENABLED) + def enableColumnarExpand: Boolean = conf.getConf(COLUMNAR_EXPAND_ENABLED) def enableColumnarBroadcastExchange: Boolean = conf.getConf(COLUMNAR_BROADCAST_EXCHANGE_ENABLED) @@ -1022,6 +1024,13 @@ object GlutenConfig { .booleanConf .createWithDefault(true) + val NATIVE_UNION_ENABLED = + buildConf("spark.gluten.sql.native.union") + .internal() + .doc("Enable or disable native union where computation is completely offloaded to backend.") + .booleanConf + .createWithDefault(false) + val COLUMNAR_EXPAND_ENABLED = buildConf("spark.gluten.sql.columnar.expand") .internal() From 11d658ec4541fcf2c470c64e745a0c8c79ecfb32 Mon Sep 17 00:00:00 2001 From: kevinyhzou <37431499+KevinyhZou@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:44:49 +0800 Subject: [PATCH 3/5] remove least and greatest function (#8091) --- .../Functions/FunctionGreatestLeast.h | 80 ------------------- .../Functions/SparkFunctionGreatest.cpp | 38 --------- .../Functions/SparkFunctionLeast.cpp | 38 --------- .../CommonScalarFunctionParser.cpp | 4 +- .../clickhouse/ClickHouseTestSettings.scala | 2 - .../clickhouse/ClickHouseTestSettings.scala | 2 - .../clickhouse/ClickHouseTestSettings.scala | 2 - .../clickhouse/ClickHouseTestSettings.scala | 2 - 8 files changed, 2 insertions(+), 166 deletions(-) delete mode 100644 cpp-ch/local-engine/Functions/FunctionGreatestLeast.h delete mode 100644 cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp delete mode 100644 cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp diff --git a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h b/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h deleted file mode 100644 index e9b66df84ef02..0000000000000 --- a/cpp-ch/local-engine/Functions/FunctionGreatestLeast.h +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include -#include -#include - -namespace DB -{ -namespace ErrorCodes -{ - extern const int NUMBER_OF_ARGUMENTS_DOESNT_MATCH; -} -} -namespace local_engine -{ -template -class FunctionGreatestestLeast : public DB::FunctionLeastGreatestGeneric -{ -public: - bool useDefaultImplementationForNulls() const override { return false; } - virtual String getName() const = 0; - -private: - DB::DataTypePtr getReturnTypeImpl(const DB::DataTypes & types) const override - { - if (types.empty()) - throw DB::Exception(DB::ErrorCodes::NUMBER_OF_ARGUMENTS_DOESNT_MATCH, "Function {} cannot be called without arguments", getName()); - return makeNullable(getLeastSupertype(types)); - } - - DB::ColumnPtr executeImpl(const DB::ColumnsWithTypeAndName & arguments, const DB::DataTypePtr & result_type, size_t input_rows_count) const override - { - size_t num_arguments = arguments.size(); - DB::Columns converted_columns(num_arguments); - for (size_t arg = 0; arg < num_arguments; ++arg) - converted_columns[arg] = castColumn(arguments[arg], result_type)->convertToFullColumnIfConst(); - auto result_column = result_type->createColumn(); - result_column->reserve(input_rows_count); - for (size_t row_num = 0; row_num < input_rows_count; ++row_num) - { - size_t best_arg = 0; - for (size_t arg = 1; arg < num_arguments; ++arg) - { - if constexpr (kind == DB::LeastGreatest::Greatest) - { - auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], -1); - if (cmp_result > 0) - best_arg = arg; - } - else - { - auto cmp_result = converted_columns[arg]->compareAt(row_num, row_num, *converted_columns[best_arg], 1); - if (cmp_result < 0) - best_arg = arg; - } - } - result_column->insertFrom(*converted_columns[best_arg], row_num); - } - return result_column; - } -}; - -} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp b/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp deleted file mode 100644 index 920fe1b9c9cce..0000000000000 --- a/cpp-ch/local-engine/Functions/SparkFunctionGreatest.cpp +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -#include - -namespace local_engine -{ -class SparkFunctionGreatest : public FunctionGreatestestLeast -{ -public: - static constexpr auto name = "sparkGreatest"; - static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } - SparkFunctionGreatest() = default; - ~SparkFunctionGreatest() override = default; - String getName() const override - { - return name; - } -}; - -REGISTER_FUNCTION(SparkGreatest) -{ - factory.registerFunction(); -} -} diff --git a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp b/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp deleted file mode 100644 index 70aafdf072095..0000000000000 --- a/cpp-ch/local-engine/Functions/SparkFunctionLeast.cpp +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -namespace local_engine -{ -class SparkFunctionLeast : public FunctionGreatestestLeast -{ -public: - static constexpr auto name = "sparkLeast"; - static DB::FunctionPtr create(DB::ContextPtr) { return std::make_shared(); } - SparkFunctionLeast() = default; - ~SparkFunctionLeast() override = default; - String getName() const override - { - return name; - } -}; - -REGISTER_FUNCTION(SparkLeast) -{ - factory.registerFunction(); -} -} diff --git a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp index ec8b4e0d12bf3..e4a56194c1711 100644 --- a/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp +++ b/cpp-ch/local-engine/Parser/scalar_function_parser/CommonScalarFunctionParser.cpp @@ -98,8 +98,8 @@ REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Unhex, unhex, unhex); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Hypot, hypot, hypot); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Sign, sign, sign); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Radians, radians, radians); -REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, sparkGreatest); -REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, sparkLeast); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Greatest, greatest, greatest); +REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Least, least, least); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rand, rand, randCanonical); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Bin, bin, sparkBin); REGISTER_COMMON_SCALAR_FUNCTION_PARSER(Rint, rint, sparkRint); diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 2eb5bd11ffbe0..36d5b5177c6b3 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ 
b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -554,8 +554,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-17617: % (Remainder) double % double on super big double") .exclude("Abs") .exclude("pmod") - .exclude("function least") - .exclude("function greatest") .exclude("SPARK-28322: IntegralDivide supports decimal type") .exclude("SPARK-33008: division by zero on divide-like operations returns incorrect result") .exclude("SPARK-34920: error class") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index a7bf5d4da903e..b9bf4e1ac40fa 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -570,8 +570,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-17617: % (Remainder) double % double on super big double") .exclude("Abs") .exclude("pmod") - .exclude("function least") - .exclude("function greatest") .exclude("SPARK-28322: IntegralDivide supports decimal type") .exclude("SPARK-33008: division by zero on divide-like operations returns incorrect result") .exclude("SPARK-34920: error class") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index b7e3905740fbc..a407c5d682478 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -513,8 +513,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-17617: % (Remainder) double % double on super big double") .exclude("Abs") .exclude("pmod") - .exclude("function least") - .exclude("function greatest") .exclude("SPARK-28322: IntegralDivide supports decimal type") .exclude("SPARK-33008: division by zero on divide-like operations returns incorrect result") .exclude("SPARK-34920: error class") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala index 8ce145735dc31..9c22af0434af1 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/clickhouse/ClickHouseTestSettings.scala @@ -513,8 +513,6 @@ class ClickHouseTestSettings extends BackendTestSettings { .exclude("SPARK-17617: % (Remainder) double % double on super big double") .exclude("Abs") .exclude("pmod") - .exclude("function least") - .exclude("function greatest") .exclude("SPARK-28322: IntegralDivide supports decimal type") .exclude("SPARK-33008: division by zero on divide-like operations returns incorrect result") .exclude("SPARK-34920: error class") From 4b0a1c10939fa1314e5d892173dd8ed9afb3441e Mon Sep 17 00:00:00 2001 From: Zhen Li <10524738+zhli1142015@users.noreply.github.com> Date: Tue, 3 Dec 2024 11:54:43 +0800 Subject: [PATCH 4/5] [VL] Enable Spark JsonExpressionsSuite (#8099) --- .../gluten/utils/velox/VeloxTestSettings.scala | 16 ++++++++++++++++ 
.../gluten/utils/velox/VeloxTestSettings.scala | 18 +++++++++++++++++- .../gluten/utils/velox/VeloxTestSettings.scala | 18 +++++++++++++++++- .../gluten/utils/velox/VeloxTestSettings.scala | 18 +++++++++++++++++- 4 files changed, 67 insertions(+), 3 deletions(-) diff --git a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 276a8a0f7af21..15495270a189b 100644 --- a/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark32/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -364,6 +364,22 @@ class VeloxTestSettings extends BackendTestSettings { // Rewrite the following two tests in GlutenDatasetSuite. .exclude("dropDuplicates: columns with same column name") .exclude("groupBy.as") + enableSuite[GlutenJsonExpressionsSuite] + // https://github.com/apache/incubator-gluten/issues/8102 + .exclude("$.store.book") + .exclude("$") + .exclude("$.store.book[0]") + .exclude("$.store.book[*]") + .exclude("$.store.book[*].category") + .exclude("$.store.book[*].isbn") + .exclude("$.store.book[*].reader") + .exclude("$.store.basket[*]") + .exclude("$.store.basket[*][0]") + .exclude("$.store.basket[0][*]") + .exclude("$.store.basket[*][*]") + .exclude("$.store.basket[0][*].b") + // Exception class different. + .exclude("from_json - invalid data") enableSuite[GlutenJsonFunctionsSuite] // Velox does not support single quotes in get_json_object function. .exclude("function get_json_object - support single quotes") diff --git a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index 08fab69c849b9..407b9c8b95cc3 100644 --- a/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark33/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.GlutenSortShuffleSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenAnsiCastSuiteWithAnsiModeOff, GlutenAnsiCastSuiteWithAnsiModeOn, GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCastSuiteWithAnsiModeOn, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryCastSuite} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector._ import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution._ @@ -142,6 +142,22 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenHashExpressionsSuite] enableSuite[GlutenHigherOrderFunctionsSuite] enableSuite[GlutenIntervalExpressionsSuite] + enableSuite[GlutenJsonExpressionsSuite] + // 
https://github.com/apache/incubator-gluten/issues/8102 + .exclude("$.store.book") + .exclude("$") + .exclude("$.store.book[0]") + .exclude("$.store.book[*]") + .exclude("$.store.book[*].category") + .exclude("$.store.book[*].isbn") + .exclude("$.store.book[*].reader") + .exclude("$.store.basket[*]") + .exclude("$.store.basket[*][0]") + .exclude("$.store.basket[0][*]") + .exclude("$.store.basket[*][*]") + .exclude("$.store.basket[0][*].b") + // Exception class different. + .exclude("from_json - invalid data") enableSuite[GlutenJsonFunctionsSuite] // Velox does not support single quotes in get_json_object function. .exclude("function get_json_object - support single quotes") diff --git a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index df79b663bcbeb..dbb01fbe70678 100644 --- a/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark34/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.GlutenSortShuffleSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryEvalSuite} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.{GlutenDataSourceV2DataFrameSessionCatalogSuite, GlutenDataSourceV2DataFrameSuite, GlutenDataSourceV2FunctionSuite, GlutenDataSourceV2SQLSessionCatalogSuite, GlutenDataSourceV2SQLSuiteV1Filter, GlutenDataSourceV2SQLSuiteV2Filter, GlutenDataSourceV2Suite, GlutenDeleteFromTableSuite, GlutenDeltaBasedDeleteFromTableSuite, GlutenFileDataSourceV2FallBackSuite, GlutenGroupBasedDeleteFromTableSuite, GlutenKeyGroupedPartitioningSuite, GlutenLocalScanSuite, GlutenMetadataColumnSuite, GlutenSupportsCatalogOptionsSuite, GlutenTableCapabilityCheckSuite, GlutenWriteDistributionAndOrderingSuite} import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution.{FallbackStrategiesSuite, GlutenBroadcastExchangeSuite, GlutenCoalesceShufflePartitionsSuite, GlutenExchangeSuite, GlutenLocalBroadcastExchangeSuite, GlutenReplaceHashWithSortAggSuite, GlutenReuseExchangeAndSubquerySuite, GlutenSameResultSuite, GlutenSortSuite, GlutenSQLAggregateFunctionSuite, GlutenSQLWindowFunctionSuite, GlutenTakeOrderedAndProjectSuite} @@ -120,6 +120,22 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenHashExpressionsSuite] enableSuite[GlutenHigherOrderFunctionsSuite] enableSuite[GlutenIntervalExpressionsSuite] + enableSuite[GlutenJsonExpressionsSuite] + // https://github.com/apache/incubator-gluten/issues/8102 + .exclude("$.store.book") + 
.exclude("$") + .exclude("$.store.book[0]") + .exclude("$.store.book[*]") + .exclude("$.store.book[*].category") + .exclude("$.store.book[*].isbn") + .exclude("$.store.book[*].reader") + .exclude("$.store.basket[*]") + .exclude("$.store.basket[*][0]") + .exclude("$.store.basket[0][*]") + .exclude("$.store.basket[*][*]") + .exclude("$.store.basket[0][*].b") + // Exception class different. + .exclude("from_json - invalid data") enableSuite[GlutenJsonFunctionsSuite] // Velox does not support single quotes in get_json_object function. .exclude("function get_json_object - support single quotes") diff --git a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala index cb2151fe4698f..f5a1a076956e0 100644 --- a/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala +++ b/gluten-ut/spark35/src/test/scala/org/apache/gluten/utils/velox/VeloxTestSettings.scala @@ -20,7 +20,7 @@ import org.apache.gluten.utils.{BackendTestSettings, SQLQueryTestSettings} import org.apache.spark.GlutenSortShuffleSuite import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.{GlutenArithmeticExpressionSuite, GlutenBitwiseExpressionsSuite, GlutenCastSuite, GlutenCollectionExpressionsSuite, GlutenComplexTypeSuite, GlutenConditionalExpressionSuite, GlutenDateExpressionsSuite, GlutenDecimalExpressionSuite, GlutenDecimalPrecisionSuite, GlutenHashExpressionsSuite, GlutenHigherOrderFunctionsSuite, GlutenIntervalExpressionsSuite, GlutenLiteralExpressionSuite, GlutenMathExpressionsSuite, GlutenMiscExpressionsSuite, GlutenNondeterministicSuite, GlutenNullExpressionsSuite, GlutenPredicateSuite, GlutenRandomSuite, GlutenRegexpExpressionsSuite, GlutenSortOrderExpressionsSuite, GlutenStringExpressionsSuite, GlutenTryEvalSuite} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector._ import org.apache.spark.sql.errors.{GlutenQueryCompilationErrorsDSv2Suite, GlutenQueryCompilationErrorsSuite, GlutenQueryExecutionErrorsSuite, GlutenQueryParsingErrorsSuite} import org.apache.spark.sql.execution._ @@ -121,6 +121,22 @@ class VeloxTestSettings extends BackendTestSettings { enableSuite[GlutenHashExpressionsSuite] enableSuite[GlutenHigherOrderFunctionsSuite] enableSuite[GlutenIntervalExpressionsSuite] + enableSuite[GlutenJsonExpressionsSuite] + // https://github.com/apache/incubator-gluten/issues/8102 + .exclude("$.store.book") + .exclude("$") + .exclude("$.store.book[0]") + .exclude("$.store.book[*]") + .exclude("$.store.book[*].category") + .exclude("$.store.book[*].isbn") + .exclude("$.store.book[*].reader") + .exclude("$.store.basket[*]") + .exclude("$.store.basket[*][0]") + .exclude("$.store.basket[0][*]") + .exclude("$.store.basket[*][*]") + .exclude("$.store.basket[0][*].b") + // Exception class different. 
+ .exclude("from_json - invalid data") enableSuite[GlutenJsonFunctionsSuite] // * in get_json_object expression not supported in velox .exclude("SPARK-42782: Hive compatibility check for get_json_object") From 2346584a72cae67494a073a4f0141386bad723db Mon Sep 17 00:00:00 2001 From: Wenzheng Liu Date: Tue, 3 Dec 2024 15:57:26 +0800 Subject: [PATCH 5/5] [GLUTEN-7028][CH][Part-11] Support write parquet files with bucket (#8052) * [GLUTEN-7028][CH] Support write parquet files with bucket * [GLUTEN-7028][CH] Fix comment --- .../execution/FileDeltaColumnarWrite.scala | 6 +- .../backendsapi/clickhouse/CHBackend.scala | 11 +- .../clickhouse/CHIteratorApi.scala | 13 +-- .../backendsapi/clickhouse/CHRuleApi.scala | 1 + .../clickhouse/RuntimeSettings.scala | 6 - .../extension/WriteFilesWithBucketValue.scala | 76 ++++++++++++ .../spark/sql/execution/CHColumnarWrite.scala | 46 ++++++-- ...lutenClickHouseNativeWriteTableSuite.scala | 15 +-- .../Parser/RelParsers/WriteRelParser.cpp | 14 +-- .../Parser/RelParsers/WriteRelParser.h | 1 - .../Storages/MergeTree/SparkMergeTreeSink.h | 6 + .../Storages/Output/NormalFileWriter.cpp | 4 +- .../Storages/Output/NormalFileWriter.h | 110 +++++++++++++++--- .../Storages/Output/OutputFormatFile.cpp | 1 - .../tests/gtest_write_pipeline.cpp | 6 +- 15 files changed, 242 insertions(+), 74 deletions(-) create mode 100644 backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala diff --git a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala index 784614152f639..bf6b0c0074dc8 100644 --- a/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala +++ b/backends-clickhouse/src/main/delta-32/org/apache/spark/sql/execution/FileDeltaColumnarWrite.scala @@ -108,13 +108,15 @@ case class FileDeltaColumnarWrite( * {{{ * part-00000-7d672b28-c079-4b00-bb0a-196c15112918-c000.snappy.parquet * => - * part-00000-{}.snappy.parquet + * part-00000-{id}.snappy.parquet * }}} */ val guidPattern = """.*-([0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12})(?:-c(\d+)\..*)?$""".r val fileNamePattern = - guidPattern.replaceAllIn(writeFileName, m => writeFileName.replace(m.group(1), "{}")) + guidPattern.replaceAllIn( + writeFileName, + m => writeFileName.replace(m.group(1), FileNamePlaceHolder.ID)) logDebug(s"Native staging write path: $writePath and with pattern: $fileNamePattern") val settings = diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala index c6c8acf7050e3..e5eb91b69bdb0 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHBackend.scala @@ -246,20 +246,11 @@ object CHBackendSettings extends BackendSettingsApi with Logging { } } - def validateBucketSpec(): Option[String] = { - if (bucketSpec.nonEmpty) { - Some("Unsupported native write: bucket write is not supported.") - } else { - None - } - } - validateCompressionCodec() .orElse(validateFileFormat()) .orElse(validateFieldMetadata()) .orElse(validateDateTypes()) - .orElse(validateWriteFilesOptions()) - .orElse(validateBucketSpec()) match { + .orElse(validateWriteFilesOptions()) match { case Some(reason) => 
ValidationResult.failed(reason) case _ => ValidationResult.succeeded } diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala index ff268b95d8deb..878e27a5b8c28 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHIteratorApi.scala @@ -26,7 +26,7 @@ import org.apache.gluten.sql.shims.SparkShimLoader import org.apache.gluten.substrait.plan.PlanNode import org.apache.gluten.substrait.rel._ import org.apache.gluten.substrait.rel.LocalFilesNode.ReadFileFormat -import org.apache.gluten.vectorized.{BatchIterator, CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator, NativeExpressionEvaluator} +import org.apache.gluten.vectorized.{BatchIterator, CHNativeExpressionEvaluator, CloseableCHColumnBatchIterator} import org.apache.spark.{InterruptibleIterator, SparkConf, TaskContext} import org.apache.spark.affinity.CHAffinity @@ -322,17 +322,6 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { createNativeIterator(splitInfoByteArray, wsPlan, materializeInput, inputIterators)) } - /** - * This function used to inject the staging write path before initializing the native plan.Only - * used in a pipeline model (spark 3.5) for writing parquet or orc files. - */ - override def injectWriteFilesTempPath(path: String, fileName: String): Unit = { - val settings = - Map( - RuntimeSettings.TASK_WRITE_TMP_DIR.key -> path, - RuntimeSettings.TASK_WRITE_FILENAME.key -> fileName) - NativeExpressionEvaluator.updateQueryRuntimeSettings(settings) - } } class CollectMetricIterator( diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala index edf7a4802535e..98cfa0e7547b6 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/CHRuleApi.scala @@ -79,6 +79,7 @@ object CHRuleApi { injector.injectPreTransform(_ => RewriteSubqueryBroadcast()) injector.injectPreTransform(c => FallbackBroadcastHashJoin.apply(c.session)) injector.injectPreTransform(c => MergeTwoPhasesHashBaseAggregate.apply(c.session)) + injector.injectPreTransform(_ => WriteFilesWithBucketValue) // Legacy: The legacy transform rule. 
val validatorBuilder: GlutenConfig => Validator = conf => diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala index b59bb32392dfc..c2747cf1eb538 100644 --- a/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/backendsapi/clickhouse/RuntimeSettings.scala @@ -35,12 +35,6 @@ object RuntimeSettings { .stringConf .createWithDefault("") - val TASK_WRITE_FILENAME = - buildConf(runtimeSettings("gluten.task_write_filename")) - .doc("The temporary file name for writing data") - .stringConf - .createWithDefault("") - val TASK_WRITE_FILENAME_PATTERN = buildConf(runtimeSettings("gluten.task_write_filename_pattern")) .doc("The pattern to generate file name for writing delta parquet in spark 3.5") diff --git a/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala new file mode 100644 index 0000000000000..8ab78dcff9bce --- /dev/null +++ b/backends-clickhouse/src/main/scala/org/apache/gluten/extension/WriteFilesWithBucketValue.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.gluten.extension + +import org.apache.gluten.GlutenConfig + +import org.apache.spark.sql.catalyst.expressions.{Alias, BitwiseAnd, Expression, HiveHash, Literal, Pmod} +import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{ProjectExec, SparkPlan} +import org.apache.spark.sql.execution.datasources.WriteFilesExec + +/** + * Wrap with bucket value to specify the bucket file name in native write. Native writer will remove + * this value in the final output. 
+ */ +object WriteFilesWithBucketValue extends Rule[SparkPlan] { + + val optionForHiveCompatibleBucketWrite = "__hive_compatible_bucketed_table_insertion__" + + override def apply(plan: SparkPlan): SparkPlan = { + if ( + GlutenConfig.getConf.enableGluten + && GlutenConfig.getConf.enableNativeWriter.getOrElse(false) + ) { + plan.transformDown { + case writeFiles: WriteFilesExec if writeFiles.bucketSpec.isDefined => + val bucketIdExp = getWriterBucketIdExp(writeFiles) + val wrapBucketValue = ProjectExec( + writeFiles.child.output :+ Alias(bucketIdExp, "__bucket_value__")(), + writeFiles.child) + writeFiles.copy(child = wrapBucketValue) + } + } else { + plan + } + } + + private def getWriterBucketIdExp(writeFilesExec: WriteFilesExec): Expression = { + val partitionColumns = writeFilesExec.partitionColumns + val outputColumns = writeFilesExec.child.output + val dataColumns = outputColumns.filterNot(partitionColumns.contains) + val bucketSpec = writeFilesExec.bucketSpec.get + val bucketColumns = bucketSpec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get) + if (writeFilesExec.options.getOrElse(optionForHiveCompatibleBucketWrite, "false") == "true") { + val hashId = BitwiseAnd(HiveHash(bucketColumns), Literal(Int.MaxValue)) + Pmod(hashId, Literal(bucketSpec.numBuckets)) + // The bucket file name prefix is following Hive, Presto and Trino conversion, so this + // makes sure Hive bucketed table written by Spark, can be read by other SQL engines. + // + // Hive: `org.apache.hadoop.hive.ql.exec.Utilities#getBucketIdFromFile()`. + // Trino: `io.trino.plugin.hive.BackgroundHiveSplitLoader#BUCKET_PATTERNS`. + + } else { + // Spark bucketed table: use `HashPartitioning.partitionIdExpression` as bucket id + // expression, so that we can guarantee the data distribution is same between shuffle and + // bucketed data source, which enables us to only shuffle one side when join a bucketed + // table and a normal one. 
+ HashPartitioning(bucketColumns, bucketSpec.numBuckets).partitionIdExpression + } + } +} diff --git a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala index 6c7877cc020bd..1342e250430ee 100644 --- a/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala +++ b/backends-clickhouse/src/main/scala/org/apache/spark/sql/execution/CHColumnarWrite.scala @@ -16,7 +16,8 @@ */ package org.apache.spark.sql.execution -import org.apache.gluten.backendsapi.BackendsApiManager +import org.apache.gluten.backendsapi.clickhouse.RuntimeSettings +import org.apache.gluten.vectorized.NativeExpressionEvaluator import org.apache.spark.TaskContext import org.apache.spark.internal.Logging @@ -25,11 +26,11 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.delta.stats.DeltaJobStatisticsTracker -import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, BasicWriteTaskStats, ExecutedWriteSummary, PartitioningUtils, WriteJobDescription, WriteTaskResult, WriteTaskStatsTracker} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.util.Utils import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.{JobID, OutputCommitter, TaskAttemptContext, TaskAttemptID, TaskID, TaskType} +import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl @@ -102,6 +103,12 @@ object CreateFileNameSpec { } } +// More details in local_engine::FileNameGenerator in NormalFileWriter.cpp +object FileNamePlaceHolder { + val ID = "{id}" + val BUCKET = "{bucket}" +} + /** [[HadoopMapReduceAdapter]] for [[HadoopMapReduceCommitProtocol]]. */ case class HadoopMapReduceAdapter(sparkCommitter: HadoopMapReduceCommitProtocol) { private lazy val committer: OutputCommitter = { @@ -132,12 +139,26 @@ case class HadoopMapReduceAdapter(sparkCommitter: HadoopMapReduceCommitProtocol) GetFilename.invoke(sparkCommitter, taskContext, spec).asInstanceOf[String] } - def getTaskAttemptTempPathAndFilename( + def getTaskAttemptTempPathAndFilePattern( taskContext: TaskAttemptContext, description: WriteJobDescription): (String, String) = { val stageDir = newTaskAttemptTempPath(description.path) - val filename = getFilename(taskContext, CreateFileNameSpec(taskContext, description)) - (stageDir, filename) + + if (isBucketWrite(description)) { + val filePart = getFilename(taskContext, FileNameSpec("", "")) + val fileSuffix = CreateFileNameSpec(taskContext, description).suffix + (stageDir, s"${filePart}_${FileNamePlaceHolder.BUCKET}$fileSuffix") + } else { + val filename = getFilename(taskContext, CreateFileNameSpec(taskContext, description)) + (stageDir, filename) + } + } + + private def isBucketWrite(desc: WriteJobDescription): Boolean = { + // In Spark 3.2, bucketSpec is not defined, instead, it uses bucketIdExpression. + val bucketSpecField: Field = desc.getClass.getDeclaredField("bucketSpec") + bucketSpecField.setAccessible(true) + bucketSpecField.get(desc).asInstanceOf[Option[_]].isDefined } } @@ -234,10 +255,15 @@ case class HadoopMapReduceCommitProtocolWrite( * initializing the native plan and collect native write files metrics for each backend. 
*/ override def doSetupNativeTask(): Unit = { - val (writePath, writeFileName) = - adapter.getTaskAttemptTempPathAndFilename(taskAttemptContext, description) - logDebug(s"Native staging write path: $writePath and file name: $writeFileName") - BackendsApiManager.getIteratorApiInstance.injectWriteFilesTempPath(writePath, writeFileName) + val (writePath, writeFilePattern) = + adapter.getTaskAttemptTempPathAndFilePattern(taskAttemptContext, description) + logDebug(s"Native staging write path: $writePath and file pattern: $writeFilePattern") + + val settings = + Map( + RuntimeSettings.TASK_WRITE_TMP_DIR.key -> writePath, + RuntimeSettings.TASK_WRITE_FILENAME_PATTERN.key -> writeFilePattern) + NativeExpressionEvaluator.updateQueryRuntimeSettings(settings) } def doCollectNativeResult(stats: Seq[InternalRow]): Option[WriteTaskResult] = { diff --git a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala index 03d27f33b1e13..16ed302a02f41 100644 --- a/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala +++ b/backends-clickhouse/src/test/scala/org/apache/gluten/execution/hive/GlutenClickHouseNativeWriteTableSuite.scala @@ -553,7 +553,7 @@ class GlutenClickHouseNativeWriteTableSuite // spark write does not support bucketed table // https://issues.apache.org/jira/browse/SPARK-19256 val table_name = table_name_template.format(format) - writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq, isSparkVersionLE("3.3")) { + writeAndCheckRead(origin_table, table_name, fields_.keys.toSeq) { fields => spark .table("origin_table") @@ -589,8 +589,9 @@ class GlutenClickHouseNativeWriteTableSuite ("byte_field", "byte"), ("boolean_field", "boolean"), ("decimal_field", "decimal(23,12)"), - ("date_field", "date"), - ("timestamp_field", "timestamp") + ("date_field", "date") + // ("timestamp_field", "timestamp") + // FIXME https://github.com/apache/incubator-gluten/issues/8053 ) val origin_table = "origin_table" withSource(genTestData(), origin_table) { @@ -598,7 +599,7 @@ class GlutenClickHouseNativeWriteTableSuite format => val table_name = table_name_template.format(format) val testFields = fields.keys.toSeq - writeAndCheckRead(origin_table, table_name, testFields, isSparkVersionLE("3.3")) { + writeAndCheckRead(origin_table, table_name, testFields) { fields => spark .table(origin_table) @@ -658,7 +659,7 @@ class GlutenClickHouseNativeWriteTableSuite nativeWrite { format => val table_name = table_name_template.format(format) - writeAndCheckRead(origin_table, table_name, fields.keys.toSeq, isSparkVersionLE("3.3")) { + writeAndCheckRead(origin_table, table_name, fields.keys.toSeq) { fields => spark .table("origin_table") @@ -762,7 +763,7 @@ class GlutenClickHouseNativeWriteTableSuite format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") - withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) { + withNativeWriteCheck(checkNative = true) { spark .range(10000000) .selectExpr("id", "cast('2020-01-01' as date) as p") @@ -798,7 +799,7 @@ class GlutenClickHouseNativeWriteTableSuite format => val table_name = table_name_template.format(format) spark.sql(s"drop table IF EXISTS $table_name") - withNativeWriteCheck(checkNative = isSparkVersionLE("3.3")) { + withNativeWriteCheck(checkNative = true) { spark .range(30000) 
.selectExpr("id", "cast(null as string) as p") diff --git a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp index 2dacb39918588..a76b4d398d975 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp +++ b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.cpp @@ -56,10 +56,11 @@ DB::ProcessorPtr make_sink( const std::string & format_hint, const std::shared_ptr & stats) { - if (partition_by.empty()) + bool no_bucketed = !SparkPartitionedBaseSink::isBucketedWrite(input_header); + if (partition_by.empty() && no_bucketed) { return std::make_shared( - context, base_path, "", generator.generate(), format_hint, input_header, stats, DeltaStats{input_header.columns()}); + context, base_path, "", false, generator.generate(), format_hint, input_header, stats, DeltaStats{input_header.columns()}); } return std::make_shared( @@ -184,13 +185,10 @@ void addNormalFileWriterSinkTransform( if (write_settings.task_write_tmp_dir.empty()) throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject temp directory."); - if (write_settings.task_write_filename.empty() && write_settings.task_write_filename_pattern.empty()) - throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject file name or file name pattern."); + if (write_settings.task_write_filename_pattern.empty()) + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Write Pipeline need inject file pattern."); - FileNameGenerator generator{ - .pattern = write_settings.task_write_filename.empty(), - .filename_or_pattern - = write_settings.task_write_filename.empty() ? write_settings.task_write_filename_pattern : write_settings.task_write_filename}; + FileNameGenerator generator(write_settings.task_write_filename_pattern); auto stats = WriteStats::create(output, partitionCols); diff --git a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h index 0c9bc11f1f103..01e0dabaaa7df 100644 --- a/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h +++ b/cpp-ch/local-engine/Parser/RelParsers/WriteRelParser.h @@ -44,7 +44,6 @@ DB::Names collect_partition_cols(const DB::Block & header, const substrait::Name #define WRITE_RELATED_SETTINGS(M, ALIAS) \ M(String, task_write_tmp_dir, , "The temporary directory for writing data") \ - M(String, task_write_filename, , "The filename for writing data") \ M(String, task_write_filename_pattern, , "The pattern to generate file name for writing delta parquet in spark 3.5") DECLARE_GLUTEN_SETTINGS(GlutenWriteSettings, WRITE_RELATED_SETTINGS) diff --git a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h index 38f574ea9872e..b551d86d1d0cc 100644 --- a/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h +++ b/cpp-ch/local-engine/Storages/MergeTree/SparkMergeTreeSink.h @@ -278,6 +278,12 @@ class SparkMergeTreePartitionedFileSink final : public SparkPartitionedBaseSink return SparkMergeTreeSink::create( table, write_settings, context_->getGlobalContext(), {std::dynamic_pointer_cast(stats_)}); } + + // TODO implement with bucket + DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) override + { + return createSinkForPartition(partition_id); + } }; } diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp index ad2e3abf7b527..2d70380a8959e 100644 --- 
a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp +++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.cpp @@ -30,6 +30,8 @@ using namespace DB; const std::string SubstraitFileSink::NO_PARTITION_ID{"__NO_PARTITION_ID__"}; const std::string SparkPartitionedBaseSink::DEFAULT_PARTITION_NAME{"__HIVE_DEFAULT_PARTITION__"}; +const std::string SparkPartitionedBaseSink::BUCKET_COLUMN_NAME{"__bucket_value__"}; +const std::vector FileNameGenerator::SUPPORT_PLACEHOLDERS{"{id}", "{bucket}"}; /// For Nullable(Map(K, V)) or Nullable(Array(T)), if the i-th row is null, we must make sure its nested data is empty. /// It is for ORC/Parquet writing compatiability. For more details, refer to @@ -168,7 +170,7 @@ void NormalFileWriter::write(DB::Block & block) const auto & preferred_schema = file->getPreferredSchema(); for (auto & column : block) { - if (column.name.starts_with("__bucket_value__")) + if (column.name.starts_with(SparkPartitionedBaseSink::BUCKET_COLUMN_NAME)) continue; const auto & preferred_column = preferred_schema.getByPosition(index++); diff --git a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h index 8cfe079d92c58..998f8d624721a 100644 --- a/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h +++ b/cpp-ch/local-engine/Storages/Output/NormalFileWriter.h @@ -230,20 +230,57 @@ class WriteStats : public WriteStatsBase struct FileNameGenerator { - const bool pattern; - const std::string filename_or_pattern; + // Align with org.apache.spark.sql.execution.FileNamePlaceHolder + static const std::vector SUPPORT_PLACEHOLDERS; + // Align with placeholders above + const std::vector need_to_replace; + const std::string file_pattern; + + FileNameGenerator(const std::string & file_pattern) + : file_pattern(file_pattern), need_to_replace(compute_need_to_replace(file_pattern)) + { + } + + std::vector compute_need_to_replace(const std::string & file_pattern) + { + std::vector result; + for(const std::string& placeholder: SUPPORT_PLACEHOLDERS) + { + if (file_pattern.find(placeholder) != std::string::npos) + result.push_back(true); + else + result.push_back(false); + } + return result; + } + + std::string generate(const std::string & bucket = "") const + { + std::string result = file_pattern; + if (need_to_replace[0]) // {id} + result = pattern_format(SUPPORT_PLACEHOLDERS[0], toString(DB::UUIDHelpers::generateV4())); + if (need_to_replace[1]) // {bucket} + result = pattern_format(SUPPORT_PLACEHOLDERS[1], bucket); + return result; + } - std::string generate() const + std::string pattern_format(const std::string & arg, const std::string & replacement) const { - if (pattern) - return fmt::vformat(filename_or_pattern, fmt::make_format_args(toString(DB::UUIDHelpers::generateV4()))); - return filename_or_pattern; + std::string format_str = file_pattern; + size_t pos = format_str.find(arg); + while (pos != std::string::npos) + { + format_str.replace(pos, arg.length(), replacement); + pos = format_str.find(arg, pos + arg.length()); + } + return format_str; } }; class SubstraitFileSink final : public DB::SinkToStorage { const std::string partition_id_; + const bool bucketed_write_; const std::string relative_path_; OutputFormatFilePtr format_file_; OutputFormatFile::OutputFormatPtr output_format_; @@ -265,6 +302,7 @@ class SubstraitFileSink final : public DB::SinkToStorage const DB::ContextPtr & context, const std::string & base_path, const std::string & partition_id, + const bool bucketed_write, const std::string & relative, const 
std::string & format_hint, const DB::Block & header, @@ -272,6 +310,7 @@ class SubstraitFileSink final : public DB::SinkToStorage const DeltaStats & delta_stats) : SinkToStorage(header) , partition_id_(partition_id.empty() ? NO_PARTITION_ID : partition_id) + , bucketed_write_(bucketed_write) , relative_path_(relative) , format_file_(createOutputFormatFile(context, makeAbsoluteFilename(base_path, partition_id, relative), header, format_hint)) , stats_(std::dynamic_pointer_cast(stats)) @@ -287,7 +326,18 @@ class SubstraitFileSink final : public DB::SinkToStorage delta_stats_.update(chunk); if (!output_format_) [[unlikely]] output_format_ = format_file_->createOutputFormat(); - output_format_->output->write(materializeBlock(getHeader().cloneWithColumns(chunk.detachColumns()))); + + const DB::Block & input_header = getHeader(); + if (bucketed_write_) + { + chunk.erase(input_header.columns() - 1); + const DB::ColumnsWithTypeAndName & cols = input_header.getColumnsWithTypeAndName(); + DB::ColumnsWithTypeAndName without_bucket_cols(cols.begin(), cols.end() - 1); + DB::Block without_bucket_header = DB::Block(without_bucket_cols); + output_format_->output->write(materializeBlock(without_bucket_header.cloneWithColumns(chunk.detachColumns()))); + } + else + output_format_->output->write(materializeBlock(input_header.cloneWithColumns(chunk.detachColumns()))); } void onFinish() override { @@ -303,11 +353,19 @@ class SubstraitFileSink final : public DB::SinkToStorage class SparkPartitionedBaseSink : public DB::PartitionedSink { - static const std::string DEFAULT_PARTITION_NAME; public: + static const std::string DEFAULT_PARTITION_NAME; + static const std::string BUCKET_COLUMN_NAME; + + static bool isBucketedWrite(const DB::Block & input_header) + { + return input_header.has(BUCKET_COLUMN_NAME) && + input_header.getPositionByName(BUCKET_COLUMN_NAME) == input_header.columns() - 1; + } + /// visible for UTs - static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns) + static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns, const DB::Block & input_header) { /// Parse the following expression into ASTs /// cancat('/col_name=', 'toString(col_name)') @@ -327,13 +385,35 @@ class SparkPartitionedBaseSink : public DB::PartitionedSink makeASTFunction("toString", DB::ASTs{column_ast}), std::make_shared(DEFAULT_PARTITION_NAME)}; arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args))); } + if (isBucketedWrite(input_header)) + { + DB::ASTs args {std::make_shared("%05d"), std::make_shared(BUCKET_COLUMN_NAME)}; + arguments.emplace_back(DB::makeASTFunction("printf", std::move(args))); + } + assert(!arguments.empty()); + if (arguments.size() == 1) + return arguments[0]; return DB::makeASTFunction("concat", std::move(arguments)); } + DB::SinkPtr createSinkForPartition(const String & partition_id) override + { + if (bucketed_write_) + { + std::string bucket_val = partition_id.substr(partition_id.length() - 5, 5); + std::string real_partition_id = partition_id.substr(0, partition_id.length() - 5); + return createSinkForPartition(real_partition_id, bucket_val); + } + return createSinkForPartition(partition_id, ""); + } + + virtual DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) = 0; + protected: DB::ContextPtr context_; std::shared_ptr stats_; DeltaStats empty_delta_stats_; + bool bucketed_write_; public: SparkPartitionedBaseSink( @@ -341,9 +421,10 @@ class SparkPartitionedBaseSink : public DB::PartitionedSink const 
DB::Names & partition_by, const DB::Block & input_header, const std::shared_ptr & stats) - : PartitionedSink(make_partition_expression(partition_by), context, input_header) + : PartitionedSink(make_partition_expression(partition_by, input_header), context, input_header) , context_(context) , stats_(stats) + , bucketed_write_(isBucketedWrite(input_header)) , empty_delta_stats_(DeltaStats::create(input_header, partition_by)) { } @@ -353,6 +434,7 @@ class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink { const std::string base_path_; const FileNameGenerator generator_; + const DB::Block input_header_; const DB::Block sample_block_; const std::string format_hint_; @@ -370,18 +452,20 @@ class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink , base_path_(base_path) , generator_(generator) , sample_block_(sample_block) + , input_header_(input_header) , format_hint_(format_hint) { } - DB::SinkPtr createSinkForPartition(const String & partition_id) override + DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) override { assert(stats_); - std::string filename = generator_.generate(); + bool bucketed_write = !bucket.empty(); + std::string filename = bucketed_write ? generator_.generate(bucket) : generator_.generate(); const auto partition_path = fmt::format("{}/{}", partition_id, filename); validatePartitionKey(partition_path, true); return std::make_shared( - context_, base_path_, partition_id, filename, format_hint_, sample_block_, stats_, empty_delta_stats_); + context_, base_path_, partition_id, bucketed_write, filename, format_hint_, sample_block_, stats_, empty_delta_stats_); } String getName() const override { return "SubstraitPartitionedFileSink"; } }; diff --git a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp index 194d997ddf827..d5ed4309439c3 100644 --- a/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp +++ b/cpp-ch/local-engine/Storages/Output/OutputFormatFile.cpp @@ -61,7 +61,6 @@ Block OutputFormatFile::createHeaderWithPreferredSchema(const Block & header) ColumnWithTypeAndName column(preferred_column.type->createColumn(), preferred_column.type, preferred_column.name); columns.emplace_back(std::move(column)); } - assert(preferred_schema.columns() == index); return {std::move(columns)}; } diff --git a/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp b/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp index 00f2da20c5224..b764f62f542e2 100644 --- a/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp +++ b/cpp-ch/local-engine/tests/gtest_write_pipeline.cpp @@ -206,14 +206,14 @@ TEST(WritePipeline, SubstraitPartitionedFileSink) TEST(WritePipeline, ComputePartitionedExpression) { const auto context = DB::Context::createCopy(QueryContext::globalContext()); - - auto partition_by = SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey", "name"}); + + Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}}; + auto partition_by = SubstraitPartitionedFileSink::make_partition_expression({"s_nationkey", "name"}, sample_block); // auto partition_by = printColumn("s_nationkey"); ASTs arguments(1, partition_by); ASTPtr partition_by_string = makeASTFunction("toString", std::move(arguments)); - Block sample_block{{STRING(), "name"}, {UINT(), "s_nationkey"}}; auto syntax_result = TreeRewriter(context).analyze(partition_by_string, sample_block.getNamesAndTypesList()); auto partition_by_expr = 
ExpressionAnalyzer(partition_by_string, syntax_result, context).getActions(false);
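
Editor's note on the file-name pattern introduced above: the `{id}` / `{bucket}` placeholders declared in `FileNamePlaceHolder` (Scala side) and consumed by `local_engine::FileNameGenerator` (NormalFileWriter.h) amount to simple string substitution — `{id}` becomes a fresh UUID per file and `{bucket}` becomes the already-formatted bucket id. The following standalone Scala sketch is illustrative only and not part of the patch; `FileNamePatternSketch` and `expand` are invented names.

import java.util.UUID

// Minimal sketch of the placeholder expansion performed natively by
// local_engine::FileNameGenerator; unknown text in the pattern is left untouched.
object FileNamePatternSketch {
  val ID = "{id}"         // replaced with a random UUID per output file
  val BUCKET = "{bucket}" // replaced with the pre-formatted bucket id

  def expand(pattern: String, bucket: Option[String]): String = {
    val withId = pattern.replace(ID, UUID.randomUUID().toString)
    bucket.fold(withId)(b => withId.replace(BUCKET, b))
  }
}

// e.g. expand("part-00000-{id}_{bucket}.c000.snappy.parquet", Some("00003"))
//   -> "part-00000-<uuid>_00003.c000.snappy.parquet"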
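
On the bucketed-write handshake: the writer marks a bucketed plan by appending the reserved `__bucket_value__` column as the last column of the input header; `SparkPartitionedBaseSink::isBucketedWrite` checks for exactly that trailing column, and `SubstraitFileSink::consume` erases it again before the rows reach the output format. A hedged Scala sketch of the same convention over plain column names (the exact-match check mirrors `isBucketedWrite`; `BucketColumnSketch` and its helpers are invented for illustration):

// Sketch of the trailing-bucket-column convention between the Scala writer and the
// native sink: a write is bucketed iff the last output column is "__bucket_value__",
// and that column is stripped before data is written to the file.
object BucketColumnSketch {
  val BucketColumnName = "__bucket_value__" // aligns with SparkPartitionedBaseSink::BUCKET_COLUMN_NAME

  def isBucketedWrite(columnNames: Seq[String]): Boolean =
    columnNames.lastOption.contains(BucketColumnName)

  def dropBucketColumn(columnNames: Seq[String]): Seq[String] =
    if (isBucketedWrite(columnNames)) columnNames.init else columnNames
}

// isBucketedWrite(Seq("id", "p", "__bucket_value__")) -> true
// dropBucketColumn(Seq("id", "p", "__bucket_value__")) -> Seq("id", "p")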
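
Finally, the bucket id travels through the partition key itself: `make_partition_expression` appends a fixed-width `printf('%05d', __bucket_value__)` suffix, and `SparkPartitionedBaseSink::createSinkForPartition` splits the last five characters back off before delegating to the bucket-aware overload. A minimal Scala sketch of that round trip, assuming the 5-character convention shown in the patch (`BucketSuffixSketch` is illustrative, not patch code):

// Sketch of the fixed-width bucket-id convention: the partition key is
// "<partition dirs><5-digit bucket>", so the last five characters can be
// peeled off again when creating the per-partition, per-bucket sink.
object BucketSuffixSketch {
  private val Width = 5

  def encode(partitionDirs: String, bucketId: Int): String =
    partitionDirs + f"$bucketId%05d"

  def decode(partitionKey: String): (String, String) = {
    val bucket = partitionKey.takeRight(Width)
    val partitionDirs = partitionKey.dropRight(Width)
    (partitionDirs, bucket)
  }
}

// encode("/p=2020-01-01/", 3)    -> "/p=2020-01-01/00003"
// decode("/p=2020-01-01/00003")  -> ("/p=2020-01-01/", "00003")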