Skip to content

Commit

Permalink
[CORE][VL] RAS: Refactor memo cache to look up on cluster-canonical n…
Browse files Browse the repository at this point in the history
…ode rather than on group-canonical node (apache#5305)
  • Loading branch information
zhztheplayer authored Apr 8, 2024
1 parent e10ee1b commit 993e96a
Show file tree
Hide file tree
Showing 26 changed files with 453 additions and 278 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ object GlutenProperties {
val conv = getProperty(plan)
plan.children.map(_ => conv)
}

override def any(): Convention = Conventions.ANY
}

case class ConventionEnforcerRule(reqConv: Convention) extends RasRule[SparkPlan] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ package org.apache.gluten.ras
*/
trait MetadataModel[T <: AnyRef] {
def metadataOf(node: T): Metadata
def dummy(): Metadata
def verify(one: Metadata, other: Metadata): Unit

def dummy(): Metadata
}

trait Metadata {}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ trait Property[T <: AnyRef] {
}

trait PropertyDef[T <: AnyRef, P <: Property[T]] {
def any(): P
def getProperty(plan: T): P
def getChildrenConstraints(constraint: Property[T], plan: T): Seq[P]
}
Expand Down
48 changes: 39 additions & 9 deletions gluten-ras/common/src/main/scala/org/apache/gluten/ras/Ras.scala
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ class Ras[T <: AnyRef] private (
ruleFactory)
}

// Normal groups start with ID 0, so it's safe to use -1 to do validation.
private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(propertyModel, planModel)
// Normal groups start with ID 0, so it's safe to use Int.MinValue to do validation.
private val dummyGroup: T =
planModel.newGroupLeaf(-1, metadataModel.dummy(), PropertySet(Seq.empty))
planModel.newGroupLeaf(Int.MinValue, metadataModel.dummy(), propSetFactory.any())
private val infCost: Cost = costModel.makeInfCost()

validateModels()
Expand Down Expand Up @@ -123,8 +124,6 @@ class Ras[T <: AnyRef] private (
}
}

private val propSetFactory: PropertySetFactory[T] = PropertySetFactory(this)

override def newPlanner(
plan: T,
constraintSet: PropertySet[T],
Expand Down Expand Up @@ -171,6 +170,8 @@ class Ras[T <: AnyRef] private (
private[ras] def getInfCost(): Cost = infCost

private[ras] def isInfCost(cost: Cost) = costModel.costComparator().equiv(cost, infCost)

private[ras] def toUnsafeKey(node: T): UnsafeKey[T] = UnsafeKey(this, node)
}

object Ras {
Expand All @@ -192,16 +193,29 @@ object Ras {
}

trait PropertySetFactory[T <: AnyRef] {
def any(): PropertySet[T]
def get(node: T): PropertySet[T]
def childrenConstraintSets(constraintSet: PropertySet[T], node: T): Seq[PropertySet[T]]
}

private object PropertySetFactory {
def apply[T <: AnyRef](ras: Ras[T]): PropertySetFactory[T] = new PropertySetFactoryImpl[T](ras)

private class PropertySetFactoryImpl[T <: AnyRef](val ras: Ras[T])
def apply[T <: AnyRef](
propertyModel: PropertyModel[T],
planModel: PlanModel[T]): PropertySetFactory[T] =
new PropertySetFactoryImpl[T](propertyModel, planModel)

private class PropertySetFactoryImpl[T <: AnyRef](
propertyModel: PropertyModel[T],
planModel: PlanModel[T])
extends PropertySetFactory[T] {
private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = ras.propertyModel.propertyDefs
private val propDefs: Seq[PropertyDef[T, _ <: Property[T]]] = propertyModel.propertyDefs
private val anyConstraint = {
val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
propDefs.map(propDef => (propDef, propDef.any())).toMap
PropertySet[T](m)
}

override def any(): PropertySet[T] = anyConstraint

override def get(node: T): PropertySet[T] = {
val m: Map[PropertyDef[T, _ <: Property[T]], Property[T]] =
Expand All @@ -213,7 +227,7 @@ object Ras {
constraintSet: PropertySet[T],
node: T): Seq[PropertySet[T]] = {
val builder: Seq[mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]] =
ras.planModel
planModel
.childrenOf(node)
.map(_ => mutable.Map[PropertyDef[T, _ <: Property[T]], Property[T]]())

Expand All @@ -236,4 +250,20 @@ object Ras {
}
}
}

trait UnsafeKey[T]

private object UnsafeKey {
def apply[T <: AnyRef](ras: Ras[T], self: T): UnsafeKey[T] = new UnsafeKeyImpl(ras, self)
private class UnsafeKeyImpl[T <: AnyRef](ras: Ras[T], val self: T) extends UnsafeKey[T] {
override def hashCode(): Int = ras.planModel.hashCode(self)
override def equals(other: Any): Boolean = {
other match {
case that: UnsafeKeyImpl[T] => ras.planModel.equals(self, that.self)
case _ => false
}
}
override def toString: String = ras.explain.describeNode(self)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.ras

import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.memo.MemoTable
import org.apache.gluten.ras.property.PropertySet

Expand Down Expand Up @@ -54,16 +55,19 @@ object RasCluster {
override val ras: Ras[T],
metadata: Metadata)
extends MutableRasCluster[T] {
private val buffer: mutable.Set[CanonicalNode[T]] =
mutable.Set()
private val deDup: mutable.Set[UnsafeKey[T]] = mutable.Set()
private val buffer: mutable.ListBuffer[CanonicalNode[T]] =
mutable.ListBuffer()

override def contains(t: CanonicalNode[T]): Boolean = {
buffer.contains(t)
deDup.contains(t.toUnsafeKey())
}

override def add(t: CanonicalNode[T]): Unit = {
val key = t.toUnsafeKey()
assert(!deDup.contains(key))
ras.metadataModel.verify(metadata, ras.metadataModel.metadataOf(t.self()))
assert(!buffer.contains(t))
deDup += key
buffer += t
}

Expand All @@ -75,12 +79,12 @@ object RasCluster {

case class ImmutableRasCluster[T <: AnyRef] private (
ras: Ras[T],
override val nodes: Set[CanonicalNode[T]])
override val nodes: Seq[CanonicalNode[T]])
extends RasCluster[T]

object ImmutableRasCluster {
def apply[T <: AnyRef](ras: Ras[T], cluster: RasCluster[T]): ImmutableRasCluster[T] = {
ImmutableRasCluster(ras, cluster.nodes().toSet)
ImmutableRasCluster(ras, cluster.nodes().toVector)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*/
package org.apache.gluten.ras

import org.apache.gluten.ras.Ras.UnsafeKey
import org.apache.gluten.ras.property.PropertySet

trait RasNode[T <: AnyRef] {
Expand All @@ -41,6 +42,8 @@ object RasNode {
def asGroup(): GroupNode[T] = {
node.asInstanceOf[GroupNode[T]]
}

def toUnsafeKey(): UnsafeKey[T] = node.ras().toUnsafeKey(node.self())
}
}

Expand All @@ -53,7 +56,7 @@ object CanonicalNode {
assert(ras.isCanonical(canonical))
val propSet = ras.propSetsOf(canonical)
val children = ras.planModel.childrenOf(canonical)
CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
new CanonicalNodeImpl[T](ras, canonical, propSet, children.size)
}

// We put RasNode's API methods that accept mutable input in implicit definition.
Expand All @@ -74,12 +77,16 @@ object CanonicalNode {
}
}

private case class CanonicalNodeImpl[T <: AnyRef](
ras: Ras[T],
private class CanonicalNodeImpl[T <: AnyRef](
override val ras: Ras[T],
override val self: T,
override val propSet: PropertySet[T],
override val childrenCount: Int)
extends CanonicalNode[T]
extends CanonicalNode[T] {
override def toString: String = ras.explain.describeNode(self)
override def hashCode(): Int = throw new UnsupportedOperationException()
override def equals(obj: Any): Boolean = throw new UnsupportedOperationException()
}
}

trait GroupNode[T <: AnyRef] extends RasNode[T] {
Expand All @@ -88,15 +95,19 @@ trait GroupNode[T <: AnyRef] extends RasNode[T] {

object GroupNode {
def apply[T <: AnyRef](ras: Ras[T], group: RasGroup[T]): GroupNode[T] = {
GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
new GroupNodeImpl[T](ras, group.self(), group.propSet(), group.id())
}

private case class GroupNodeImpl[T <: AnyRef](
ras: Ras[T],
private class GroupNodeImpl[T <: AnyRef](
override val ras: Ras[T],
override val self: T,
override val propSet: PropertySet[T],
override val groupId: Int)
extends GroupNode[T] {}
extends GroupNode[T] {
override def toString: String = ras.explain.describeNode(self)
override def hashCode(): Int = throw new UnsupportedOperationException()
override def equals(obj: Any): Boolean = throw new UnsupportedOperationException()
}

// We put RasNode's API methods that accept mutable input in implicit definition.
// Do not break this rule during further development.
Expand All @@ -116,8 +127,21 @@ object InGroupNode {
def apply[T <: AnyRef](groupId: Int, node: CanonicalNode[T]): InGroupNode[T] = {
InGroupNodeImpl(groupId, node)
}

private case class InGroupNodeImpl[T <: AnyRef](groupId: Int, can: CanonicalNode[T])
extends InGroupNode[T]

trait HashKey extends Any

implicit class InGroupNodeImplicits[T <: AnyRef](n: InGroupNode[T]) {
import InGroupNodeImplicits._
def toHashKey: HashKey =
InGroupNodeHashKeyImpl(n.groupId, System.identityHashCode(n.can))
}

private object InGroupNodeImplicits {
private case class InGroupNodeHashKeyImpl(gid: Int, cid: Int) extends HashKey
}
}

trait InClusterNode[T <: AnyRef] {
Expand All @@ -129,8 +153,21 @@ object InClusterNode {
def apply[T <: AnyRef](clusterId: RasClusterKey, node: CanonicalNode[T]): InClusterNode[T] = {
InClusterNodeImpl(clusterId, node)
}

private case class InClusterNodeImpl[T <: AnyRef](
clusterKey: RasClusterKey,
can: CanonicalNode[T])
extends InClusterNode[T]

trait HashKey extends Any

implicit class InClusterNodeImplicits[T <: AnyRef](n: InClusterNode[T]) {
import InClusterNodeImplicits._
def toHashKey: HashKey =
InClusterNodeHashKeyImpl(n.clusterKey, System.identityHashCode(n.can))
}

private object InClusterNodeImplicits {
private case class InClusterNodeHashKeyImpl(clusterKey: RasClusterKey, cid: Int) extends HashKey
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ object RasPlanner {
trait Best[T <: AnyRef] {
import Best._
def rootGroupId(): Int
def bestNodes(): Set[InGroupNode[T]]
def winnerNodes(): Set[InGroupNode[T]]
def bestNodes(): InGroupNode[T] => Boolean
def winnerNodes(): InGroupNode[T] => Boolean
def costs(): InGroupNode[T] => Option[Cost]
def path(): KnownCostPath[T]
}
Expand All @@ -62,11 +62,11 @@ object Best {
bestPath: KnownCostPath[T],
winnerNodes: Seq[InGroupNode[T]],
costs: InGroupNode[T] => Option[Cost]): Best[T] = {
val bestNodes = mutable.Set[InGroupNode[T]]()
val bestNodes = mutable.Set[InGroupNode.HashKey]()

def dfs(groupId: Int, cursor: RasPath.PathNode[T]): Unit = {
val can = cursor.self().asCanonical()
bestNodes += InGroupNode(groupId, can)
bestNodes += InGroupNode(groupId, can).toHashKey
cursor.zipChildrenWithGroupIds().foreach {
case (childPathNode, childGroupId) =>
dfs(childGroupId, childPathNode)
Expand All @@ -75,17 +75,24 @@ object Best {

dfs(rootGroupId, bestPath.rasPath.node())

val winnerNodeSet = winnerNodes.toSet
val bestNodeSet = bestNodes.toSet
val winnerNodeSet = winnerNodes.map(_.toHashKey).toSet

BestImpl(ras, rootGroupId, bestPath, bestNodes.toSet, winnerNodeSet, costs)
BestImpl(
ras,
rootGroupId,
bestPath,
n => bestNodeSet.contains(n.toHashKey),
n => winnerNodeSet.contains(n.toHashKey),
costs)
}

private case class BestImpl[T <: AnyRef](
ras: Ras[T],
override val rootGroupId: Int,
override val path: KnownCostPath[T],
override val bestNodes: Set[InGroupNode[T]],
override val winnerNodes: Set[InGroupNode[T]],
override val bestNodes: InGroupNode[T] => Boolean,
override val winnerNodes: InGroupNode[T] => Boolean,
override val costs: InGroupNode[T] => Option[Cost])
extends Best[T]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,12 @@ object BestFinder {
}

case class KnownCostGroup[T <: AnyRef](
nodeToCost: Map[CanonicalNode[T], KnownCostPath[T]],
nodes: Iterable[CanonicalNode[T]],
nodeToCost: CanonicalNode[T] => Option[KnownCostPath[T]],
bestNode: CanonicalNode[T]) {
def best(): KnownCostPath[T] = nodeToCost(bestNode)
def best(): KnownCostPath[T] = {
nodeToCost(bestNode).get
}
}

case class KnownCostCluster[T <: AnyRef](groupToCost: Map[Int, KnownCostGroup[T]])
Expand All @@ -52,17 +55,21 @@ object BestFinder {
allGroups: Seq[RasGroup[T]],
group: RasGroup[T],
groupToCosts: Map[Int, KnownCostGroup[T]]): Best[T] = {

val bestPath = groupToCosts(group.id()).best()
val bestRoot = bestPath.rasPath.node()
val winnerNodes = groupToCosts.map { case (id, g) => InGroupNode(id, g.bestNode) }.toSeq
val costsMap = mutable.Map[InGroupNode[T], Cost]()
val costsMap = mutable.Map[InGroupNode.HashKey, Cost]()
groupToCosts.foreach {
case (gid, g) =>
g.nodeToCost.foreach {
case (n, c) =>
costsMap += (InGroupNode(gid, n) -> c.cost)
g.nodes.foreach {
n =>
val c = g.nodeToCost(n)
if (c.nonEmpty) {
costsMap += (InGroupNode(gid, n).toHashKey -> c.get.cost)
}
}
}
Best(ras, group.id(), bestPath, winnerNodes, costsMap.get)
Best(ras, group.id(), bestPath, winnerNodes, ign => costsMap.get(ign.toHashKey))
}
}
Loading

0 comments on commit 993e96a

Please sign in to comment.