Skip to content

Commit

Permalink
[VL] RAS: New rule RemoveSort to remove unnecessary sorts (#6107)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Jun 18, 2024
1 parent a722af3 commit eb653ba
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ case class EnumeratedTransform(session: SparkSession, outputsColumnar: Boolean)

private val rules = List(
new PushFilterToScan(RasOffload.validator),
RemoveSort,
RemoveFilter
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.extension.columnar.enumerated

import org.apache.gluten.execution.{HashAggregateExecBaseTransformer, ShuffledHashJoinExecTransformerBase, SortExecTransformer}
import org.apache.gluten.extension.GlutenPlan
import org.apache.gluten.ras.path.Pattern._
import org.apache.gluten.ras.path.Pattern.Matchers._
import org.apache.gluten.ras.rule.{RasRule, Shape}
import org.apache.gluten.ras.rule.Shapes._

import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.execution.SparkPlan

/**
* Removes unnecessary sort if its parent doesn't require for sorted input.
*
* TODO: Sort's removal could be made much simpler once output ordering is added as a physical
* property in RAS planer.
*/
object RemoveSort extends RasRule[SparkPlan] {
private val appliedTypes: Seq[Class[_ <: GlutenPlan]] =
List(classOf[HashAggregateExecBaseTransformer], classOf[ShuffledHashJoinExecTransformerBase])

override def shift(node: SparkPlan): Iterable[SparkPlan] = {
assert(node.isInstanceOf[GlutenPlan])
val newChildren = node.requiredChildOrdering.zip(node.children).map {
case (Nil, sort: SortExecTransformer) =>
// Parent doesn't ask for sorted input from this child but a sort op was somehow added.
// Remove it.
sort.child
case (req, child) =>
// Parent asks for sorted input from this child. Do nothing but an assertion.
assert(SortOrder.orderingSatisfies(child.outputOrdering, req))
child
}
val out = List(node.withNewChildren(newChildren))
out
}
override def shape(): Shape[SparkPlan] = pattern(
branch2[SparkPlan](
or(appliedTypes.map(clazz[SparkPlan](_)): _*),
_ >= 1,
_ => node(clazz(classOf[GlutenPlan]))
).build()
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,23 +87,51 @@ object Pattern {
override def children(count: Int): Seq[Node[T]] = (0 until count).map(_ => ignore[T])
}

private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Seq[Node[T]])
private case class Branch[T <: AnyRef](matcher: Matcher[T], children: Branch.ChildrenFactory[T])
extends Node[T] {
override def skip(): Boolean = false
override def abort(node: CanonicalNode[T]): Boolean = node.childrenCount != children.size
override def abort(node: CanonicalNode[T]): Boolean =
!children.acceptsChildrenCount(node.childrenCount)
override def matches(node: CanonicalNode[T]): Boolean = matcher(node.self())
override def children(count: Int): Seq[Node[T]] = {
assert(count == children.size)
children
assert(children.acceptsChildrenCount(count))
(0 until count).map(children.child)
}
}

private object Branch {
trait ChildrenFactory[T <: AnyRef] {
def child(index: Int): Node[T]
def acceptsChildrenCount(count: Int): Boolean
}

object ChildrenFactory {
case class Plain[T <: AnyRef](nodes: Seq[Node[T]]) extends ChildrenFactory[T] {
override def child(index: Int): Node[T] = nodes(index)
override def acceptsChildrenCount(count: Int): Boolean = nodes.size == count
}

case class Func[T <: AnyRef](arity: Int => Boolean, func: Int => Node[T])
extends ChildrenFactory[T] {
override def child(index: Int): Node[T] = func(index)
override def acceptsChildrenCount(count: Int): Boolean = arity(count)
}
}
}

def any[T <: AnyRef]: Node[T] = Any.INSTANCE.asInstanceOf[Node[T]]
def ignore[T <: AnyRef]: Node[T] = Ignore.INSTANCE.asInstanceOf[Node[T]]
def node[T <: AnyRef](matcher: Matcher[T]): Node[T] = Single(matcher)
def branch[T <: AnyRef](matcher: Matcher[T], children: Node[T]*): Node[T] =
Branch(matcher, children.toSeq)
def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] = Branch(matcher, List.empty)
Branch(matcher, Branch.ChildrenFactory.Plain(children.toSeq))
// Similar to #branch, but with unknown arity.
def branch2[T <: AnyRef](
matcher: Matcher[T],
arity: Int => Boolean,
children: Int => Node[T]): Node[T] =
Branch(matcher, Branch.ChildrenFactory.Func(arity, children))
def leaf[T <: AnyRef](matcher: Matcher[T]): Node[T] =
Branch(matcher, Branch.ChildrenFactory.Plain(List.empty))

implicit class NodeImplicits[T <: AnyRef](node: Node[T]) {
def build(): Pattern[T] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,29 @@ class PatternSuite extends AnyFunSuite {
assert(pattern.matches(path, 1))
}

test("Match branch") {
val ras =
Ras[TestNode](
PlanModelImpl,
CostModelImpl,
MetadataModelImpl,
PropertyModelImpl,
ExplainImpl,
RasRule.Factory.none())

val path1 = MockRasPath.mock(ras, Branch("n1", List()))
val path2 = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1))))
val path3 = MockRasPath.mock(ras, Branch("n1", List(Leaf("n2", 1), Leaf("n3", 1))))

val pattern =
Pattern.branch2[TestNode](n => n.isInstanceOf[Branch], _ >= 1, _ => Pattern.any).build()
assert(!pattern.matches(path1, 1))
assert(pattern.matches(path2, 1))
assert(pattern.matches(path2, 2))
assert(pattern.matches(path3, 1))
assert(pattern.matches(path3, 2))
}

test("Match unary") {
val ras =
Ras[TestNode](
Expand Down Expand Up @@ -231,17 +254,20 @@ object PatternSuite {

case class Unary(name: String, child: TestNode) extends UnaryLike {
override def selfCost(): Long = 1

override def withNewChildren(child: TestNode): UnaryLike = copy(child = child)
}

case class Binary(name: String, left: TestNode, right: TestNode) extends BinaryLike {
override def selfCost(): Long = 1

override def withNewChildren(left: TestNode, right: TestNode): BinaryLike =
copy(left = left, right = right)
}

case class Branch(name: String, children: Seq[TestNode]) extends TestNode {
override def selfCost(): Long = 1
override def withNewChildren(children: Seq[TestNode]): TestNode = copy(children = children)
}

case class DummyGroup() extends LeafLike {
override def makeCopy(): LeafLike = throw new UnsupportedOperationException()
override def selfCost(): Long = throw new UnsupportedOperationException()
Expand Down

0 comments on commit eb653ba

Please sign in to comment.