diff --git a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala index 54f4e3b84244..07dd3fe0284f 100644 --- a/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala +++ b/gluten-core/src/main/scala/org/apache/gluten/planner/property/GlutenPropertyModel.scala @@ -53,6 +53,8 @@ object GlutenProperties { val conv = getProperty(plan) plan.children.map(_ => conv) } + + override def any(): Convention = Conventions.ANY } case class ConventionEnforcerRule(reqConv: Convention) extends RasRule[SparkPlan] { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala index d2056746c398..a81ac31cba58 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/MetadataModel.scala @@ -21,8 +21,9 @@ package org.apache.gluten.ras */ trait MetadataModel[T <: AnyRef] { def metadataOf(node: T): Metadata - def dummy(): Metadata def verify(one: Metadata, other: Metadata): Unit + + def dummy(): Metadata } trait Metadata {} diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala index e2ba99136749..e764631e7777 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/PropertyModel.scala @@ -26,6 +26,7 @@ trait Property[T <: AnyRef] { } trait PropertyDef[T <: AnyRef, P <: Property[T]] { + def any(): P def getProperty(plan: T): P def getChildrenConstraints(constraint: Property[T], plan: T): Seq[P] } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala 
index 6832d07c5790..9910fab6f000 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala @@ -79,9 +79,10 @@ class Ras[T <: AnyRef] private ( ruleFactory) } - // Normal groups start with ID 0, so it's safe to use -1 to do validation. + private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(propertyModel, planModel) + // Normal groups start with ID 0, so it's safe to use Int.MinValue to do validation. private val dummyGroup: T = - planModel.newGroupLeaf(-1, metadataModel.dummy(), PropertySet(Seq.empty)) + planModel.newGroupLeaf(Int.MinValue, metadataModel.dummy(), propSetFactory.any()) private val infCost: Cost = costModel.makeInfCost() validateModels() @@ -123,8 +124,6 @@ class Ras[T <: AnyRef] private ( } } - private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(this) - override def newPlanner( plan: T, constraintSet: PropertySet[T], @@ -171,6 +170,8 @@ class Ras[T <: AnyRef] private ( private[ras] def getInfCost(): Cost = infCost private[ras] def isInfCost(cost: Cost) = costModel.costComparator().equiv(cost, infCost) + + private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node) } object Ras { @@ -192,16 +193,29 @@ object Ras { } trait PropertySetFactory[T <: AnyRef] { + def any(): PropertySet[T] def get(node: T): PropertySet[T] def childrenConstraintSets(constraintSet: PropertySet[T], node: T): Seq[PropertySet[T]] } private object PropertySetFactory { - def apply[T <: AnyRef](ras: Ras[T]): PropertySetFactory[T] = new PropertySetFactoryImpl[T](ras) - - private class PropertySetFactoryImpl[T <: AnyRef](val ras: Ras[T]) + def apply[T <: AnyRef]( + propertyModel: PropertyModel[T], + planModel: PlanModel[T]): PropertySetFactory[T] = + new PropertySetFactoryImpl[T](propertyModel, planModel) + + private class PropertySetFactoryImpl[T <: AnyRef]( + propertyModel: PropertyModel[T], + planModel: PlanModel[T]) extends 
PropertySetFactory[T] { - private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = ras.propertyModel.propertyDefs + private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = propertyModel.propertyDefs + private val anyConstraint = { + val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = + propDefs.map(propDef => (propDef, propDef.any())).toMap + PropertySet[T](m) + } + + override def any(): PropertySet[T] = anyConstraint override def get(node: T): PropertySet[T] = { val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] = @@ -213,7 +227,7 @@ object Ras { constraintSet: PropertySet[T], node: T): Seq[PropertySet[T]] = { val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]] = - ras.planModel + planModel .childrenOf(node) .map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]()) @@ -236,4 +250,20 @@ object Ras { } } } + + trait UnsafeKey[T] + + private object UnsafeKey { + def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new UnsafeKeyImpl(ras, self) + private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends UnsafeKey[T] { + override def hashCode(): Int = ras.planModel.hashCode(self) + override def equals(other: Any): Boolean = { + other match { + case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self) + case _ => false + } + } + override def toString: String = ras.explain.describeNode(self) + } + } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala index 63b8b1e68273..1b30e1242c82 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasCluster.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.ras +import org.apache.gluten.ras.Ras.UnsafeKey import org.apache.gluten.ras.memo.MemoTable import org.apache.gluten.ras.property.PropertySet @@ -54,16 +55,19 @@ object RasCluster 
{ override val ras: Ras[T], metadata: Metadata) extends MutableRasCluster[T] { - private val buffer: mutable.Set[CanonicalNode[T]] = - mutable.Set() + private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set() + private val buffer: mutable.ListBuffer[CanonicalNode[T]] = + mutable.ListBuffer() override def contains(t: CanonicalNode[T]): Boolean = { - buffer.contains(t) + deDup.contains(t.toUnsafeKey()) } override def add(t: CanonicalNode[T]): Unit = { + val key = t.toUnsafeKey() + assert(!deDup.contains(key)) ras.metadataModel.verify(metadata, ras.metadataModel.metadataOf(t.self())) - assert(!buffer.contains(t)) + deDup += key buffer += t } @@ -75,12 +79,12 @@ object RasCluster { case class ImmutableRasCluster[T <: AnyRef] private ( ras: Ras[T], - override val nodes: Set[CanonicalNode[T]]) + override val nodes: Seq[CanonicalNode[T]]) extends RasCluster[T] object ImmutableRasCluster { def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]): ImmutableRasCluster[T] = { - ImmutableRasCluster(ras, cluster.nodes().toSet) + ImmutableRasCluster(ras, cluster.nodes().toVector) } } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala index 5f18f96a7ab8..878020391c46 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasNode.scala @@ -16,6 +16,7 @@ */ package org.apache.gluten.ras +import org.apache.gluten.ras.Ras.UnsafeKey import org.apache.gluten.ras.property.PropertySet trait RasNode[T <: AnyRef] { @@ -41,6 +42,8 @@ object RasNode { def asGroup(): GroupNode[T] = { node.asInstanceOf[GroupNode[T]] } + + def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self()) } } @@ -53,7 +56,7 @@ object CanonicalNode { assert(ras.isCanonical(canonical)) val propSet = ras.propSetsOf(canonical) val children = ras.planModel.childrenOf(canonical) - CanonicalNodeImpl[T](ras, canonical, 
propSet, children.size) + new CanonicalNodeImpl[T](ras, canonical, propSet, children.size) } // We put RasNode's API methods that accept mutable input in implicit definition. @@ -74,12 +77,16 @@ object CanonicalNode { } } - private case class CanonicalNodeImpl[T <: AnyRef]( - ras: Ras[T], + private class CanonicalNodeImpl[T <: AnyRef]( + override val ras: Ras[T], override val self: T, override val propSet: PropertySet[T], override val childrenCount: Int) - extends CanonicalNode[T] + extends CanonicalNode[T] { + override def toString: String = ras.explain.describeNode(self) + override def hashCode(): Int = throw new UnsupportedOperationException() + override def equals(obj: Any): Boolean = throw new UnsupportedOperationException() + } } trait GroupNode[T <: AnyRef] extends RasNode[T] { @@ -88,15 +95,19 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] { object GroupNode { def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = { - GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id()) + new GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id()) } - private case class GroupNodeImpl[T <: AnyRef]( - ras: Ras[T], + private class GroupNodeImpl[T <: AnyRef]( + override val ras: Ras[T], override val self: T, override val propSet: PropertySet[T], override val groupId: Int) - extends GroupNode[T] {} + extends GroupNode[T] { + override def toString: String = ras.explain.describeNode(self) + override def hashCode(): Int = throw new UnsupportedOperationException() + override def equals(obj: Any): Boolean = throw new UnsupportedOperationException() + } // We put RasNode's API methods that accept mutable input in implicit definition. // Do not break this rule during further development. 
@@ -116,8 +127,21 @@ object InGroupNode { def apply[T <: AnyRef](groupId: Int, node: CanonicalNode[T]): InGroupNode[T] = { InGroupNodeImpl(groupId, node) } + private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: CanonicalNode[T]) extends InGroupNode[T] + + trait HashKey extends Any + + implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) { + import InGroupNodeImplicits._ + def toHashKey: HashKey = + InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can)) + } + + private object InGroupNodeImplicits { + private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends HashKey + } } trait InClusterNode[T <: AnyRef] { @@ -129,8 +153,21 @@ object InClusterNode { def apply[T <: AnyRef](clusterId: RasClusterKey, node: CanonicalNode[T]): InClusterNode[T] = { InClusterNodeImpl(clusterId, node) } + private case class InClusterNodeImpl[T <: AnyRef]( clusterKey: RasClusterKey, can: CanonicalNode[T]) extends InClusterNode[T] + + trait HashKey extends Any + + implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) { + import InClusterNodeImplicits._ + def toHashKey: HashKey = + InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can)) + } + + private object InClusterNodeImplicits { + private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey, cid: Int) extends HashKey + } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala index 0665d3661dd6..74793a3d0fbc 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/RasPlanner.scala @@ -49,8 +49,8 @@ object RasPlanner { trait Best[T <: AnyRef] { import Best._ def rootGroupId(): Int - def bestNodes(): Set[InGroupNode[T]] - def winnerNodes(): Set[InGroupNode[T]] + def bestNodes(): InGroupNode[T] => Boolean + def winnerNodes(): InGroupNode[T] => Boolean def 
costs(): InGroupNode[T] => Option[Cost] def path(): KnownCostPath[T] } @@ -62,11 +62,11 @@ object Best { bestPath: KnownCostPath[T], winnerNodes: Seq[InGroupNode[T]], costs: InGroupNode[T] => Option[Cost]): Best[T] = { - val bestNodes = mutable.Set[InGroupNode[T]]() + val bestNodes = mutable.Set[InGroupNode.HashKey]() def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = { val can = cursor.self().asCanonical() - bestNodes += InGroupNode(groupId, can) + bestNodes += InGroupNode(groupId, can).toHashKey cursor.zipChildrenWithGroupIds().foreach { case (childPathNode, childGroupId) => dfs(childGroupId, childPathNode) @@ -75,17 +75,24 @@ object Best { dfs(rootGroupId, bestPath.rasPath.node()) - val winnerNodeSet = winnerNodes.toSet + val bestNodeSet = bestNodes.toSet + val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet - BestImpl(ras, rootGroupId, bestPath, bestNodes.toSet, winnerNodeSet, costs) + BestImpl( + ras, + rootGroupId, + bestPath, + n => bestNodeSet.contains(n.toHashKey), + n => winnerNodeSet.contains(n.toHashKey), + costs) } private case class BestImpl[T <: AnyRef]( ras: Ras[T], override val rootGroupId: Int, override val path: KnownCostPath[T], - override val bestNodes: Set[InGroupNode[T]], - override val winnerNodes: Set[InGroupNode[T]], + override val bestNodes: InGroupNode[T] => Boolean, + override val winnerNodes: InGroupNode[T] => Boolean, override val costs: InGroupNode[T] => Option[Cost]) extends Best[T] diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala index 4ec7e09f556f..0912ab536df7 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/BestFinder.scala @@ -40,9 +40,12 @@ object BestFinder { } case class KnownCostGroup[T <: AnyRef]( - nodeToCost: Map[CanonicalNode[T], KnownCostPath[T]], + nodes: Iterable[CanonicalNode[T]], + 
nodeToCost: CanonicalNode[T] => Option[KnownCostPath[T]], bestNode: CanonicalNode[T]) { - def best(): KnownCostPath[T] = nodeToCost(bestNode) + def best(): KnownCostPath[T] = { + nodeToCost(bestNode).get + } } case class KnownCostCluster[T <: AnyRef](groupToCost: Map[Int, KnownCostGroup[T]]) @@ -52,17 +55,21 @@ object BestFinder { allGroups: Seq[RasGroup[T]], group: RasGroup[T], groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = { + val bestPath = groupToCosts(group.id()).best() val bestRoot = bestPath.rasPath.node() val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, g.bestNode) }.toSeq - val costsMap = mutable.Map[InGroupNode[T], Cost]() + val costsMap = mutable.Map[InGroupNode.HashKey, Cost]() groupToCosts.foreach { case (gid, g) => - g.nodeToCost.foreach { - case (n, c) => - costsMap += (InGroupNode(gid, n) -> c.cost) + g.nodes.foreach { + n => + val c = g.nodeToCost(n) + if (c.nonEmpty) { + costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost) + } } } - Best(ras, group.id(), bestPath, winnerNodes, costsMap.get) + Best(ras, group.id(), bestPath, winnerNodes, ign => costsMap.get(ign.toHashKey)) } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala index 7d2d807ff824..6db3600de976 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/best/GroupBasedBestFinder.scala @@ -23,6 +23,8 @@ import org.apache.gluten.ras.dp.{DpGroupAlgo, DpGroupAlgoDef} import org.apache.gluten.ras.memo.MemoState import org.apache.gluten.ras.path.{PathKeySet, RasPath} +import java.util + // The best path's each sub-path is considered optimal in its own group. 
private class GroupBasedBestFinder[T <: AnyRef]( ras: Ras[T], @@ -94,21 +96,34 @@ private object GroupBasedBestFinder { override def solveGroup( group: RasGroup[T], nodesOutput: InGroupNode[T] => Option[KnownCostPath[T]]): Option[KnownCostGroup[T]] = { + import scala.collection.JavaConverters._ + val nodes = group.nodes(memoState) // Allow unsolved children nodes while solving group. - val flatNodesOutput = - nodes.flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp => n -> kcp)).toMap + val flatNodesOutput = new util.IdentityHashMap[CanonicalNode[T], KnownCostPath[T]]() + + nodes + .flatMap(n => nodesOutput(InGroupNode(group.id(), n)).map(kcp => n -> kcp)) + .foreach { + case (n, kcp) => + assert(!flatNodesOutput.containsKey(n)) + flatNodesOutput.put(n, kcp) + } if (flatNodesOutput.isEmpty) { return None } - val bestPath = flatNodesOutput.values.reduce { + val bestPath = flatNodesOutput.values.asScala.reduce { (left, right) => Ordering .by((cp: KnownCostPath[T]) => cp.cost)(costComparator) .min(left, right) } - Some(KnownCostGroup(flatNodesOutput, bestPath.rasPath.node().self().asCanonical())) + Some( + KnownCostGroup( + nodes, + n => Option(flatNodesOutput.get(n)), + bestPath.rasPath.node().self().asCanonical())) } override def solveNodeOnCycle(node: InGroupNode[T]): Option[KnownCostPath[T]] = diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala index 95f453f47f15..e90ba448b833 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpClusterAlgo.scala @@ -73,7 +73,7 @@ object DpClusterAlgo { clusterAlgoDef: DpClusterAlgoDef[T, NodeOutput, ClusterOutput]) extends DpZipperAlgoDef[InClusterNode[T], RasClusterKey, NodeOutput, ClusterOutput] { override def idOfX(x: InClusterNode[T]): Any = { - x + x.toHashKey } override def idOfY(y: 
RasClusterKey): Any = { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala index 6c1e998b6bbe..c824fda8e367 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpGroupAlgo.scala @@ -66,7 +66,7 @@ object DpGroupAlgo { groupAlgoDef: DpGroupAlgoDef[T, NodeOutput, GroupOutput]) extends DpZipperAlgoDef[InGroupNode[T], RasGroup[T], NodeOutput, GroupOutput] { override def idOfX(x: InGroupNode[T]): Any = { - x + x.toHashKey } override def idOfY(y: RasGroup[T]): Any = { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala index 391e7f1962d6..1be728ae6d11 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpPlanner.scala @@ -21,7 +21,7 @@ import org.apache.gluten.ras.Best.KnownCostPath import org.apache.gluten.ras.best.BestFinder import org.apache.gluten.ras.dp.DpZipperAlgo.Adjustment.Panel import org.apache.gluten.ras.memo.{Memo, MemoTable} -import org.apache.gluten.ras.path.{PathFinder, RasPath} +import org.apache.gluten.ras.path.{InClusterPath, PathFinder, RasPath} import org.apache.gluten.ras.property.PropertySet import org.apache.gluten.ras.rule.{EnforcerRuleSet, RuleApplier, Shape} @@ -172,7 +172,7 @@ object DpPlanner { rule: RuleApplier[T], path: RasPath[T]): Unit = { val probe = memoTable.probe() - rule.apply(path) + rule.apply(InClusterPath(thisClusterKey, path)) val diff = probe.toDiff() val changedClusters = diff.changedClusters() if (changedClusters.isEmpty) { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala index 
821009982755..f28edd0dcb00 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/dp/DpZipperAlgo.scala @@ -608,7 +608,6 @@ object DpZipperAlgo { } private object XKey { - // Keep argument "ele" although it is unused. To give compiler type hint. def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef]( algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput], x: X): XKey[X, Y, XOutput, YOutput] = { @@ -631,7 +630,6 @@ object DpZipperAlgo { } private object YKey { - // Keep argument "ele" although it is unused. To give compiler type hint. def apply[X <: AnyRef, Y <: AnyRef, XOutput <: AnyRef, YOutput <: AnyRef]( algoDef: DpZipperAlgoDef[X, Y, XOutput, YOutput], y: Y): YKey[X, Y, XOutput, YOutput] = { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala index a9737eb02b16..47945fc1426f 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/exaustive/ExhaustivePlanner.scala @@ -107,8 +107,8 @@ object ExhaustivePlanner { finder.find(canonical).foreach(path => onFound(path)) } - private def applyRule(rule: RuleApplier[T], path: RasPath[T]): Unit = { - rule.apply(path) + private def applyRule(rule: RuleApplier[T], icp: InClusterPath[T]): Unit = { + rule.apply(icp) } private def applyRules(): Unit = { @@ -116,10 +116,17 @@ object ExhaustivePlanner { return } val shapes = rules.map(_.shape()) - allClusters - .flatMap(c => c.nodes()) - .foreach( - node => findPaths(node, shapes)(path => rules.foreach(rule => applyRule(rule, path)))) + memoState + .clusterLookup() + .foreach { + case (cKey, cluster) => + cluster + .nodes() + .foreach( + node => + findPaths(node, shapes)( + path => rules.foreach(rule => applyRule(rule, 
InClusterPath(cKey, path))))) + } } private def applyEnforcerRules(): Unit = { @@ -129,10 +136,11 @@ object ExhaustivePlanner { val enforcerRules = enforcerRuleSet.rulesOf(constraintSet) if (enforcerRules.nonEmpty) { val shapes = enforcerRules.map(_.shape()) - memoState.clusterLookup()(group.clusterKey()).nodes().foreach { + val cKey = group.clusterKey() + memoState.clusterLookup()(cKey).nodes().foreach { node => findPaths(node, shapes)( - path => enforcerRules.foreach(rule => applyRule(rule, path))) + path => enforcerRules.foreach(rule => applyRule(rule, InClusterPath(cKey, path)))) } } } diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala index 945e653ebb0f..e3ae03ebfda2 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/ForwardMemoTable.scala @@ -33,6 +33,8 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T]) private val clusterKeyBuffer: mutable.ArrayBuffer[IntClusterKey] = mutable.ArrayBuffer() private val clusterBuffer: mutable.ArrayBuffer[MutableRasCluster[T]] = mutable.ArrayBuffer() private val clusterDisjointSet: IndexDisjointSet = IndexDisjointSet() + private val clusterDummyGroupBuffer = mutable.ArrayBuffer[RasGroup[T]]() + private val groupLookup: mutable.ArrayBuffer[mutable.Map[PropertySet[T], RasGroup[T]]] = mutable.ArrayBuffer() @@ -46,14 +48,22 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T]) override def newCluster(metadata: Metadata): RasClusterKey = { checkBufferSizes() - val key = IntClusterKey(clusterBuffer.size, metadata) + val clusterId = clusterBuffer.size + val key = IntClusterKey(clusterId, metadata) clusterKeyBuffer += key clusterBuffer += MutableRasCluster(ras, metadata) clusterDisjointSet.grow() groupLookup += mutable.Map() + // Normal groups start with ID 
0, so it's safe to use negative IDs for dummy groups. + clusterDummyGroupBuffer += RasGroup(ras, key, -(clusterId + 1), ras.propertySetFactory().any()) key } + override def dummyGroupOf(key: RasClusterKey): RasGroup[T] = { + val ancestor = ancestorClusterIdOf(key) + clusterDummyGroupBuffer(ancestor) + } + override def groupOf(key: RasClusterKey, propSet: PropertySet[T]): RasGroup[T] = { val ancestor = ancestorClusterIdOf(key) val lookup = groupLookup(ancestor) @@ -75,7 +85,11 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T]) } override def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit = { - getCluster(key).add(node) + val cluster = getCluster(key) + if (cluster.contains(node)) { + return + } + cluster.add(node) memoWriteCount += 1 } @@ -142,6 +156,7 @@ class ForwardMemoTable[T <: AnyRef] private (override val ras: Ras[T]) assert(clusterKeyBuffer.size == clusterBuffer.size) assert(clusterKeyBuffer.size == clusterDisjointSet.size) assert(clusterKeyBuffer.size == groupLookup.size) + assert(clusterKeyBuffer.size == clusterDummyGroupBuffer.size) } override def probe(): MemoTable.Probe[T] = new ForwardMemoTable.Probe[T](this) diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala index a77293586a73..66626b756c30 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/Memo.scala @@ -17,17 +17,19 @@ package org.apache.gluten.ras.memo import org.apache.gluten.ras._ +import org.apache.gluten.ras.Ras.UnsafeKey import org.apache.gluten.ras.RasCluster.ImmutableRasCluster import org.apache.gluten.ras.property.PropertySet -import org.apache.gluten.ras.util.CanonicalNodeMap import org.apache.gluten.ras.vis.GraphvizVisualizer +import scala.collection.mutable + trait MemoLike[T <: AnyRef] { def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T] 
} trait Closure[T <: AnyRef] { - def openFor(node: CanonicalNode[T]): MemoLike[T] + def openFor(cKey: RasClusterKey): MemoLike[T] } trait Memo[T <: AnyRef] extends Closure[T] with MemoLike[T] { @@ -51,82 +53,61 @@ object Memo { private class RasMemo[T <: AnyRef](val ras: Ras[T]) extends UnsafeMemo[T] { import RasMemo._ private val memoTable: MemoTable.Writable[T] = MemoTable.create(ras) - private val cache: NodeToClusterMap[T] = new NodeToClusterMap(ras) + private val cache = mutable.Map[MemoCacheKey[T], RasClusterKey]() private def newCluster(metadata: Metadata): RasClusterKey = { memoTable.newCluster(metadata) } private def addToCluster(clusterKey: RasClusterKey, can: CanonicalNode[T]): Unit = { - assert(!cache.contains(can)) - cache.put(can, clusterKey) memoTable.addToCluster(clusterKey, can) } - // Replace node's children with node groups. When a group doesn't exist, create it. - private def canonizeUnsafe(node: T, constraintSet: PropertySet[T], depth: Int): T = { - assert(depth >= 1) - if (depth > 1) { - return ras.withNewChildren( - node, - ras.planModel - .childrenOf(node) - .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node)) - .map { - case (child, constraintSet) => - canonizeUnsafe(child, constraintSet, depth - 1) - } - ) + private def clusterOfUnsafe(metadata: Metadata, cacheKey: MemoCacheKey[T]): RasClusterKey = { + if (cache.contains(cacheKey)) { + cache(cacheKey) + } else { + // Node not yet added to cluster. 
+ val cluster = newCluster(metadata) + cache += (cacheKey -> cluster) + cluster } - assert(depth == 1) - val childrenGroups: Seq[RasGroup[T]] = ras.planModel - .childrenOf(node) - .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node)) - .map { - case (child, childConstraintSet) => - memorize(child, childConstraintSet) - } - val newNode = - ras.withNewChildren(node, childrenGroups.map(group => group.self())) - newNode } - private def canonize(node: T, constraintSet: PropertySet[T]): CanonicalNode[T] = { - CanonicalNode(ras, canonizeUnsafe(node, constraintSet, 1)) + private def dummyGroupOf(clusterKey: RasClusterKey): RasGroup[T] = { + memoTable.dummyGroupOf(clusterKey) + } + + private def toCacheKeyUnsafe(n: T): MemoCacheKey[T] = { + MemoCacheKey(ras, n) } - private def insert(n: T, constraintSet: PropertySet[T]): RasClusterKey = { - if (ras.planModel.isGroupLeaf(n)) { - val plainGroup = memoTable.allGroups()(ras.planModel.getGroupId(n)) - return plainGroup.clusterKey() + private def prepareInsert(n: T): Prepare[T] = { + if (ras.isGroupLeaf(n)) { + val group = memoTable.allGroups()(ras.planModel.getGroupId(n)) + return Prepare.cluster(this, group.clusterKey()) } - val node = canonize(n, constraintSet) + val childrenPrepares = ras.planModel.childrenOf(n).map(child => prepareInsert(child)) - if (cache.contains(node)) { - cache.get(node) - } else { - // Node not yet added to cluster. 
- val meta = ras.metadataModel.metadataOf(node.self()) - val clusterKey = newCluster(meta) - addToCluster(clusterKey, node) - clusterKey - } + val canUnsafe = ras.withNewChildren( + n, + childrenPrepares.map(childPrepare => dummyGroupOf(childPrepare.clusterKey()).self())) + + val cacheKey = toCacheKeyUnsafe(canUnsafe) + + val clusterKey = clusterOfUnsafe(ras.metadataModel.metadataOf(n), cacheKey) + + Prepare.tree(this, clusterKey, childrenPrepares) } override def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T] = { - val clusterKey = insert(node, constraintSet) - val prevGroupCount = memoTable.allGroups().size - val out = memoTable.groupOf(clusterKey, constraintSet) - val newGroupCount = memoTable.allGroups().size - assert(newGroupCount >= prevGroupCount) - out + val prepare = prepareInsert(node) + prepare.doInsert(node, constraintSet) } - override def openFor(node: CanonicalNode[T]): MemoLike[T] = { - assert(cache.contains(node)) - val targetCluster = cache.get(node) - new InCusterMemo[T](this, targetCluster) + override def openFor(cKey: RasClusterKey): MemoLike[T] = { + new InCusterMemo[T](this, cKey) } override def newState(): MemoState[T] = { @@ -141,37 +122,116 @@ object Memo { } private object RasMemo { - private class InCusterMemo[T <: AnyRef](parent: RasMemo[T], preparedCluster: RasClusterKey) + private class InCusterMemo[T <: AnyRef](parent: RasMemo[T], targetCluster: RasClusterKey) extends MemoLike[T] { + private val ras = parent.ras + + private def prepareInsert(node: T): Prepare[T] = { + assert(!ras.isGroupLeaf(node)) + + val childrenPrepares = + ras.planModel.childrenOf(node).map(child => parent.prepareInsert(child)) + + val canUnsafe = ras.withNewChildren( + node, + childrenPrepares.map { + childPrepare => parent.dummyGroupOf(childPrepare.clusterKey()).self() + }) + + val cacheKey = parent.toCacheKeyUnsafe(canUnsafe) + + if (!parent.cache.contains(cacheKey)) { + // The new node was not added to memo yet. Add it to the target cluster. 
+ parent.cache += (cacheKey -> targetCluster) + return Prepare.tree(parent, targetCluster, childrenPrepares) + } + + // The new node already memorized to memo. - private def insert(node: T, constraintSet: PropertySet[T]): Unit = { - val can = parent.canonize(node, constraintSet) - if (parent.cache.contains(can)) { - val cachedCluster = parent.cache.get(can) - if (cachedCluster == preparedCluster) { - return - } - // The new node already memorized to memo, but in the different cluster - // with the input node. Merge the two clusters. - // - // TODO: Traversal up the tree to do more merges. - parent.memoTable.mergeClusters(cachedCluster, preparedCluster) - // Since new node already memorized, we don't have to add it to either of the clusters - // anymore. - return + val cachedCluster = parent.cache(cacheKey) + if (cachedCluster == targetCluster) { + // The new node already memorized to memo and in the target cluster. + return Prepare.tree(parent, targetCluster, childrenPrepares) } - parent.addToCluster(preparedCluster, can) + // The new node already memorized to memo, but in the different cluster. + // Merge the two clusters. + // + // TODO: Traverse up the tree to do more merges. 
+ parent.memoTable.mergeClusters(cachedCluster, targetCluster) + Prepare.tree(parent, targetCluster, childrenPrepares) } override def memorize(node: T, constraintSet: PropertySet[T]): RasGroup[T] = { - insert(node, constraintSet) - parent.memoTable.groupOf(preparedCluster, constraintSet) + val prepare = prepareInsert(node) + prepare.doInsert(node, constraintSet) } } + + private trait Prepare[T <: AnyRef] { + def clusterKey(): RasClusterKey + def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T] + } + + private object Prepare { + def tree[T <: AnyRef]( + memo: RasMemo[T], + cKey: RasClusterKey, + children: Seq[Prepare[T]]): Prepare[T] = { + new TreePrepare[T](memo, cKey, children) + } + + def cluster[T <: AnyRef](memo: RasMemo[T], cKey: RasClusterKey): Prepare[T] = { + new ClusterPrepare[T](memo, cKey) + } + + private class TreePrepare[T <: AnyRef]( + memo: RasMemo[T], + override val clusterKey: RasClusterKey, + children: Seq[Prepare[T]]) + extends Prepare[T] { + private val ras = memo.ras + + override def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T] = { + assert(!ras.isGroupLeaf(node)) + val childrenGroups = children + .zip(ras.planModel.childrenOf(node)) + .zip(ras.propertySetFactory().childrenConstraintSets(constraintSet, node)) + .map { + case ((childPrepare, child), childConstraintSet) => + childPrepare.doInsert(child, childConstraintSet) + } + + val canUnsafe = ras.withNewChildren(node, childrenGroups.map(group => group.self())) + val can = CanonicalNode(ras, canUnsafe) + + memo.addToCluster(clusterKey, can) + + val group = memo.memoTable.groupOf(clusterKey, constraintSet) + group + } + } + + private class ClusterPrepare[T <: AnyRef](memo: RasMemo[T], cKey: RasClusterKey) + extends Prepare[T] { + private val ras = memo.ras + override def doInsert(node: T, constraintSet: PropertySet[T]): RasGroup[T] = { + assert(ras.isGroupLeaf(node)) + memo.memoTable.groupOf(cKey, constraintSet) + } + + override def clusterKey(): RasClusterKey = 
cKey + } + } + } + + private object MemoCacheKey { + def apply[T <: AnyRef](ras: Ras[T], self: T): MemoCacheKey[T] = { + assert(ras.isCanonical(self)) + MemoCacheKey[T](ras.toUnsafeKey(self)) + } } - private class NodeToClusterMap[T <: AnyRef](ras: Ras[T]) - extends CanonicalNodeMap[T, RasClusterKey](ras) + private case class MemoCacheKey[T <: AnyRef] private (delegate: UnsafeKey[T]) } trait MemoStore[T <: AnyRef] { diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala index b54bd88117c8..3baba8eae22e 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/memo/MemoTable.scala @@ -44,6 +44,7 @@ object MemoTable { trait Writable[T <: AnyRef] extends MemoTable[T] { def newCluster(metadata: Metadata): RasClusterKey def groupOf(key: RasClusterKey, propertySet: PropertySet[T]): RasGroup[T] + def dummyGroupOf(key: RasClusterKey): RasGroup[T] def addToCluster(key: RasClusterKey, node: CanonicalNode[T]): Unit def mergeClusters(one: RasClusterKey, other: RasClusterKey): Unit diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala index ca712cec4010..61fa22e5eaad 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/path/RasPath.scala @@ -37,7 +37,7 @@ object RasPath { object PathNode { def apply[T <: AnyRef](node: RasNode[T], children: Seq[PathNode[T]]): PathNode[T] = { - PathNodeImpl(node, children) + new PathNodeImpl(node, children) } } @@ -61,7 +61,7 @@ object RasPath { keys: PathKeySet, height: Int, node: RasPath.PathNode[T]): RasPath[T] = { - RasPathImpl(ras, keys, height, node) + new RasPathImpl(ras, keys, height, node) } // Returns none if children doesn't share at 
least one path key. @@ -103,25 +103,6 @@ object RasPath { PathNode(canonical, canonical.getChildrenGroups(allGroups).map(g => PathNode(g, List.empty)))) } - // Aggregates paths that have same shape but different keys together. - // Currently not in use because of bad performance. - def aggregate[T <: AnyRef](ras: Ras[T], paths: Iterable[RasPath[T]]): Iterable[RasPath[T]] = { - // Scala has specialized optimization against small set of input of group-by. - // So it's better only to pass small inputs to this method if possible. - val grouped = paths.groupBy(_.node()) - grouped.map { - case (node, paths) => - val heights = paths.map(_.height()).toSeq.distinct - assert(heights.size == 1) - val height = heights.head - val keys = paths.map(_.keys().keys()).reduce[Set[PathKey]] { - case (one, other) => - one.union(other) - } - RasPath(ras, PathKeySet(keys), height, node) - } - } - def cartesianProduct[T <: AnyRef]( ras: Ras[T], canonical: CanonicalNode[T], @@ -171,12 +152,12 @@ object RasPath { } } - private case class PathNodeImpl[T <: AnyRef]( + private class PathNodeImpl[T <: AnyRef]( override val self: RasNode[T], override val children: Seq[PathNode[T]]) extends PathNode[T] - private case class RasPathImpl[T <: AnyRef]( + private class RasPathImpl[T <: AnyRef]( override val ras: Ras[T], override val keys: PathKeySet, override val height: Int, @@ -193,3 +174,19 @@ object RasPath { override def plan(): T = built } } + +trait InClusterPath[T <: AnyRef] { + def cluster(): RasClusterKey + def path(): RasPath[T] +} + +object InClusterPath { + def apply[T <: AnyRef](cluster: RasClusterKey, path: RasPath[T]): InClusterPath[T] = { + new InClusterPathImpl(cluster, path) + } + + private class InClusterPathImpl[T <: AnyRef]( + override val cluster: RasClusterKey, + override val path: RasPath[T]) + extends InClusterPath[T] +} diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala 
b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala index b99001e93cf2..0a7bf0c7685b 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/rule/RuleApplier.scala @@ -17,14 +17,14 @@ package org.apache.gluten.ras.rule import org.apache.gluten.ras._ +import org.apache.gluten.ras.Ras.UnsafeKey import org.apache.gluten.ras.memo.Closure -import org.apache.gluten.ras.path.RasPath -import org.apache.gluten.ras.util.CanonicalNodeMap +import org.apache.gluten.ras.path.InClusterPath import scala.collection.mutable trait RuleApplier[T <: AnyRef] { - def apply(path: RasPath[T]): Unit + def apply(icp: InClusterPath[T]): Unit def shape(): Shape[T] } @@ -42,25 +42,27 @@ object RuleApplier { private class RegularRuleApplier[T <: AnyRef](ras: Ras[T], closure: Closure[T], rule: RasRule[T]) extends RuleApplier[T] { - private val cache = new CanonicalNodeMap[T, mutable.Set[T]](ras) + private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]() - override def apply(path: RasPath[T]): Unit = { - val can = path.node().self().asCanonical() + override def apply(icp: InClusterPath[T]): Unit = { + val cKey = icp.cluster() + val path = icp.path() val plan = path.plan() - val appliedPlans = cache.getOrElseUpdate(can, mutable.Set()) - if (appliedPlans.contains(plan)) { + val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set()) + val pKey = ras.toUnsafeKey(plan) + if (appliedPlans.contains(pKey)) { return } - apply0(can, plan) - appliedPlans += plan + apply0(cKey, plan) + appliedPlans += pKey } - private def apply0(can: CanonicalNode[T], plan: T): Unit = { + private def apply0(cKey: RasClusterKey, plan: T): Unit = { val equivalents = rule.shift(plan) equivalents.foreach { equiv => closure - .openFor(can) + .openFor(cKey) .memorize(equiv, ras.propertySetFactory().get(equiv)) } } @@ -73,32 +75,35 @@ object RuleApplier { closure: Closure[T], rule: 
EnforcerRule[T]) extends RuleApplier[T] { - private val cache = new CanonicalNodeMap[T, mutable.Set[T]](ras) + private val deDup = mutable.Map[RasClusterKey, mutable.Set[UnsafeKey[T]]]() private val constraint = rule.constraint() private val constraintDef = constraint.definition() - override def apply(path: RasPath[T]): Unit = { + override def apply(icp: InClusterPath[T]): Unit = { + val cKey = icp.cluster() + val path = icp.path() val can = path.node().self().asCanonical() if (can.propSet().get(constraintDef).satisfies(constraint)) { return } val plan = path.plan() - val appliedPlans = cache.getOrElseUpdate(can, mutable.Set()) - if (appliedPlans.contains(plan)) { + val pKey = ras.toUnsafeKey(plan) + val appliedPlans = deDup.getOrElseUpdate(cKey, mutable.Set()) + if (appliedPlans.contains(pKey)) { return } - apply0(can, plan) - appliedPlans += plan + apply0(cKey, plan) + appliedPlans += pKey } - private def apply0(can: CanonicalNode[T], plan: T): Unit = { + private def apply0(cKey: RasClusterKey, plan: T): Unit = { val propSet = ras.propertySetFactory().get(plan) val constraintSet = propSet.withProp(constraint) val equivalents = rule.shift(plan) equivalents.foreach { equiv => closure - .openFor(can) + .openFor(cKey) .memorize(equiv, constraintSet) } } @@ -110,11 +115,11 @@ object RuleApplier { extends RuleApplier[T] { private val ruleShape = rule.shape() - override def apply(path: RasPath[T]): Unit = { - if (!ruleShape.identify(path)) { + override def apply(icp: InClusterPath[T]): Unit = { + if (!ruleShape.identify(icp.path())) { return } - rule.apply(path) + rule.apply(icp) } override def shape(): Shape[T] = ruleShape diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala deleted file mode 100644 index 887e00bdcdf5..000000000000 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/util/NodeMap.scala +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Licensed to 
the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.gluten.ras.util - -import org.apache.gluten.ras.{CanonicalNode, Ras} - -import scala.collection.mutable - -// Arbitrary node key. -class NodeKey[T <: AnyRef](ras: Ras[T], val node: T) { - override def hashCode(): Int = ras.planModel.hashCode(node) - - override def equals(obj: Any): Boolean = { - obj match { - case other: NodeKey[T] => ras.planModel.equals(node, other.node) - case _ => false - } - } - - override def toString(): String = s"NodeKey($node)" -} - -// Canonical node map. 
-class CanonicalNodeMap[T <: AnyRef, V](ras: Ras[T]) { - private val map: mutable.Map[NodeKey[T], V] = mutable.Map() - - def contains(node: CanonicalNode[T]): Boolean = { - map.contains(keyOf(node)) - } - - def put(node: CanonicalNode[T], value: V): Unit = { - map.put(keyOf(node), value) - } - - def get(node: CanonicalNode[T]): V = { - map(keyOf(node)) - } - - def getOrElseUpdate(node: CanonicalNode[T], op: => V): V = { - map.getOrElseUpdate(keyOf(node), op) - } - - private def keyOf(node: CanonicalNode[T]): NodeKey[T] = { - new NodeKey(ras, node.self()) - } -} diff --git a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala index 600a61edc9dc..11f6051b0cac 100644 --- a/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala +++ b/gluten-ras/common/src/main/scala/org/apache/gluten/ras/vis/GraphvizVisualizer.scala @@ -43,13 +43,13 @@ class GraphvizVisualizer[T <: AnyRef](ras: Ras[T], memoState: MemoState[T], best object IsBestNode { def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])): Boolean = { - bestNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1)) + bestNodes(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1)) } } object IsWinnerNode { def unapply(nodeAndGroupToTest: (CanonicalNode[T], RasGroup[T])): Boolean = { - winnerNodes.contains(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1)) + winnerNodes(InGroupNode(nodeAndGroupToTest._2.id(), nodeAndGroupToTest._1)) } } diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala index acd96442c38c..f1c319873355 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/OperationSuite.scala @@ -97,7 +97,7 @@ class 
OperationSuite extends AnyFunSuite { val ras = Ras[TestNode]( - PlanModelImpl, + planModel, CostModelImpl, MetadataModelImpl, PropertyModelImpl, @@ -108,7 +108,7 @@ class OperationSuite extends AnyFunSuite { val optimized = planner.plan() assert(optimized == Unary2(49, Leaf2(29))) - planModel.assertPlanOpsLte((200, 50, 50, 50)) + planModel.assertPlanOpsLte((200, 50, 100, 50)) val state = planner.newState() val allPaths = state.memoState().collectAllPaths(RasPath.INF_DEPTH).toSeq @@ -127,7 +127,7 @@ class OperationSuite extends AnyFunSuite { val ras = Ras[TestNode]( - PlanModelImpl, + planModel, CostModelImpl, MetadataModelImpl, PropertyModelImpl, @@ -163,7 +163,7 @@ class OperationSuite extends AnyFunSuite { val ras = Ras[TestNode]( - PlanModelImpl, + planModel, CostModelImpl, MetadataModelImpl, PropertyModelImpl, diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala index 8a68bbba8de6..e48604116a60 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/PropertySuite.scala @@ -19,6 +19,7 @@ package org.apache.gluten.ras import org.apache.gluten.ras.Best.BestNotFoundException import org.apache.gluten.ras.RasConfig.PlannerType import org.apache.gluten.ras.RasSuiteBase._ +import org.apache.gluten.ras.memo.Memo import org.apache.gluten.ras.property.PropertySet import org.apache.gluten.ras.rule.{RasRule, Shape, Shapes} @@ -37,6 +38,30 @@ abstract class PropertySuite extends AnyFunSuite { protected def conf: RasConfig + test("Group memo - cache") { + val ras = + Ras[TestNode]( + PlanModelImpl, + CostModelImpl, + MetadataModelImpl, + NodeTypePropertyModelWithOutEnforcerRules, + ExplainImpl, + RasRule.Factory.none()) + .withNewConfig(_ => conf) + + val memo = Memo(ras) + + memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1, TypedLeaf(TypeA, 1))))) + val 
leafGroup = memo.memorize(ras, TypedLeaf(TypeA, 1)) + memo + .openFor(leafGroup.clusterKey()) + .memorize(ras, TypedLeaf(TypeB, 1)) + memo.memorize(ras, PassNodeType(1, PassNodeType(1, PassNodeType(1, TypedLeaf(TypeB, 1))))) + val state = memo.newState() + assert(state.allClusters().size == 4) + assert(state.getGroupCount() == 8) + } + test(s"Get property") { val leaf = PLeaf(10, DummyProperty(0)) val unary = PUnary(5, DummyProperty(0), leaf) @@ -112,7 +137,7 @@ abstract class PropertySuite extends AnyFunSuite { TypedLeaf(TypeB, 10))) } - ignore(s"Memo cache hit - (A, B)") { + test(s"Memo cache hit - (A, B)") { object ReplaceLeafAByLeafBRule extends RasRule[TestNode] { override def shift(node: TestNode): Iterable[TestNode] = { node match { @@ -163,8 +188,8 @@ abstract class PropertySuite extends AnyFunSuite { val out = planner.plan() assert(out == TypedLeaf(TypeA, 1)) - // FIXME: Cluster 2 and 1 are currently able to merge but it's better to - // have them identified as the same right after HitCacheOp is applied + // Cluster 2 and 1 are able to merge but we'd make sure + // they are identified as the same right after HitCacheOp is applied val clusterCount = planner.newState().memoState().allClusters().size assert(clusterCount == 2) } @@ -531,6 +556,7 @@ object PropertySuite { } object DummyPropertyDef extends PropertyDef[TestNode, DummyProperty] { + override def any(): DummyProperty = DummyProperty(Int.MinValue) override def getProperty(plan: TestNode): DummyProperty = { plan match { case Group(_, _, _) => throw new IllegalStateException() @@ -669,6 +695,8 @@ object PropertySuite { } override def toString: String = "NodeTypeDef" + + override def any(): NodeType = TypeAny } trait NodeType extends Property[TestNode] { diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala index abb8bdecd0c2..0ad82518128f 100644 --- 
a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/RasSuite.scala @@ -66,8 +66,7 @@ abstract class RasSuite extends AnyFunSuite { val group = memo.memorize(ras, Unary(50, Unary(50, Leaf(30)))) val state = memo.newState() assert(group.nodes(state).size == 1) - val can = group.nodes(state).head.asCanonical() - memo.openFor(can).memorize(ras, Unary(30, Leaf(90))) + memo.openFor(group.clusterKey()).memorize(ras, Unary(30, Leaf(90))) assert(memo.newState().allGroups().size == 4) } @@ -87,8 +86,7 @@ abstract class RasSuite extends AnyFunSuite { assert(group.nodes(state).size == 1) val leaf40Group = memo.memorize(ras, Leaf(40)) assert(leaf40Group.nodes(state).size == 1) - val can = leaf40Group.nodes(state).head.asCanonical() - memo.openFor(can).memorize(ras, Leaf(30)) + memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30)) assert(memo.newState().allGroups().size == 3) } @@ -108,8 +106,7 @@ abstract class RasSuite extends AnyFunSuite { assert(group.nodes(state).size == 1) val leaf40Group = memo.memorize(ras, Leaf(40)) assert(leaf40Group.nodes(state).size == 1) - val can = leaf40Group.nodes(state).head.asCanonical() - memo.openFor(can).memorize(ras, Leaf(30)) + memo.openFor(leaf40Group.clusterKey()).memorize(ras, Leaf(30)) assert(memo.newState().allGroups().size == 5) } diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala index e092ea4f23f3..8158aec16886 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/path/RasPathSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.gluten.ras.path -import org.apache.gluten.ras.Ras +import org.apache.gluten.ras.{CanonicalNode, Ras} import org.apache.gluten.ras.RasSuiteBase._ import org.apache.gluten.ras.mock.MockRasPath import 
org.apache.gluten.ras.rule.RasRule @@ -26,7 +26,7 @@ import org.scalatest.funsuite.AnyFunSuite class RasPathSuite extends AnyFunSuite { import RasPathSuite._ - test("Path aggregate - empty") { + test("Cartesian product - empty") { val ras = Ras[TestNode]( PlanModelImpl, @@ -35,10 +35,21 @@ class RasPathSuite extends AnyFunSuite { PropertyModelImpl, ExplainImpl, RasRule.Factory.reuse(List.empty)) - assert(RasPath.aggregate(ras, List.empty) == List.empty) + assert( + RasPath.cartesianProduct( + ras, + CanonicalNode(ras, Binary("b", ras.dummyGroupLeaf(), ras.dummyGroupLeaf())), + List( + List.empty, + List( + MockRasPath.mock( + ras, + Leaf("l", 1), + PathKeySet(Set(DummyPathKey(3))) + ))) + ) == List.empty) } - - test("Path aggregate") { + test("Cartesian product") { val ras = Ras[TestNode]( PlanModelImpl, @@ -54,6 +65,7 @@ class RasPathSuite extends AnyFunSuite { val n4 = "n4" val n5 = "n5" val n6 = "n6" + val path1 = MockRasPath.mock( ras, Unary(n5, Leaf(n6, 1)), @@ -66,31 +78,37 @@ class RasPathSuite extends AnyFunSuite { ) val path3 = MockRasPath.mock( ras, - Unary(n1, Unary(n2, Leaf(n3, 1))), - PathKeySet(Set(DummyPathKey(1), DummyPathKey(2))) + Leaf(n6, 1), + PathKeySet(Set(DummyPathKey(1))) ) val path4 = MockRasPath.mock( ras, - Unary(n1, Unary(n2, Leaf(n3, 1))), - PathKeySet(Set(DummyPathKey(4))) + Leaf(n3, 1), + PathKeySet(Set(DummyPathKey(3))) ) + val path5 = MockRasPath.mock( ras, - Unary(n5, Leaf(n6, 1)), + Unary(n2, Leaf(n3, 1)), PathKeySet(Set(DummyPathKey(4))) ) - val out = RasPath - .aggregate(ras, List(path1, path2, path3, path4, path5)) - .toSeq - .sortBy(_.height()) - assert(out.size == 2) - assert(out.head.height() == 2) - assert(out.head.plan() == Unary(n5, Leaf(n6, 1))) - assert(out.head.keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(3), DummyPathKey(4)))) - assert(out(1).height() == 3) - assert(out(1).plan() == Unary(n1, Unary(n2, Leaf(n3, 1)))) - assert(out(1).keys() == PathKeySet(Set(DummyPathKey(1), DummyPathKey(2), 
DummyPathKey(4)))) + val product = RasPath.cartesianProduct( + ras, + CanonicalNode(ras, Binary(n4, ras.dummyGroupLeaf(), ras.dummyGroupLeaf())), + List( + List(path1, path2), + List(path3, path4, path5) + )) + + val out = product.toList + assert(out.size == 3) + + assert( + out.map(_.plan()) == List( + Binary(n4, Unary(n5, Leaf(n6, 1)), Leaf(n6, 1)), + Binary(n4, Unary(n5, Leaf(n6, 1)), Leaf(n3, 1)), + Binary(n4, Unary(n1, Unary(n2, Leaf(n3, 1))), Leaf(n6, 1)))) } } diff --git a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala index cab3d1818e31..de71cba5bc0f 100644 --- a/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala +++ b/gluten-ras/common/src/test/scala/org/apache/gluten/ras/specific/DistributedSuite.scala @@ -262,6 +262,8 @@ object DistributedSuite { case (d: Distribution, p: DNode) => p.getDistributionConstraints(d) case _ => throw new UnsupportedOperationException() } + + override def any(): Distribution = AnyDistribution } trait Ordering extends Property[TestNode] @@ -315,6 +317,8 @@ object DistributedSuite { case (o: Ordering, p: DNode) => p.getOrderingConstraints(o) case _ => throw new UnsupportedOperationException() } + + override def any(): Ordering = AnyOrdering } private class EnforceDistribution(distribution: Distribution) extends RasRule[TestNode] {