Skip to content

Commit

Permalink
[VL] RAS: Incorporate query plan's logical link into metadata model (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhztheplayer authored Jun 24, 2024
1 parent 1e06169 commit 4ed161b
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ object RemoveFilter extends RasRule[SparkPlan] {
val filter = node.asInstanceOf[FilterExecTransformerBase]
if (filter.isNoop()) {
val out = NoopFilter(filter.child, filter.output)
out.copyTagsFrom(filter)
return List(out)
}
List.empty
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,42 +18,18 @@ package org.apache.gluten.planner.metadata

import org.apache.gluten.ras.Metadata

import org.apache.spark.sql.catalyst.expressions.Attribute

sealed trait GlutenMetadata extends Metadata {
import GlutenMetadata._
def schema(): Schema
def logicalLink(): LogicalLink
}

object GlutenMetadata {
def apply(schema: Schema): Metadata = {
Impl(schema)
def apply(schema: Schema, logicalLink: LogicalLink): Metadata = {
Impl(schema, logicalLink)
}

private case class Impl(override val schema: Schema) extends GlutenMetadata

case class Schema(output: Seq[Attribute]) {
private val hash = output.map(_.semanticHash()).hashCode()

override def hashCode(): Int = {
hash
}

override def equals(obj: Any): Boolean = obj match {
case other: Schema =>
semanticEquals(other)
case _ =>
false
}

private def semanticEquals(other: Schema): Boolean = {
if (output.size != other.output.size) {
return false
}
output.zip(other.output).forall {
case (left, right) =>
left.semanticEquals(right)
}
}
private case class Impl(override val schema: Schema, override val logicalLink: LogicalLink)
extends GlutenMetadata {
override def toString: String = s"$schema,$logicalLink"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
*/
package org.apache.gluten.planner.metadata

import org.apache.gluten.planner.metadata.GlutenMetadata.Schema
import org.apache.gluten.planner.plan.GlutenPlanModel.GroupLeafExec
import org.apache.gluten.ras.{Metadata, MetadataModel}

Expand All @@ -31,18 +30,22 @@ object GlutenMetadataModel extends Logging {
private object MetadataModelImpl extends MetadataModel[SparkPlan] {
override def metadataOf(node: SparkPlan): Metadata = node match {
case g: GroupLeafExec => throw new UnsupportedOperationException()
case other => GlutenMetadata(Schema(other.output))
case other =>
GlutenMetadata(
Schema(other.output),
other.logicalLink.map(LogicalLink(_)).getOrElse(LogicalLink.notFound))
}

override def dummy(): Metadata = GlutenMetadata(Schema(List()))
override def dummy(): Metadata = GlutenMetadata(Schema(List()), LogicalLink.notFound)
override def verify(one: Metadata, other: Metadata): Unit = (one, other) match {
case (left: GlutenMetadata, right: GlutenMetadata) if left.schema() != right.schema() =>
// We apply loose restriction on schema. Since Gluten still have some customized
// logics causing schema of an operator to change after being transformed.
// For example: https://github.com/apache/incubator-gluten/pull/5171
logWarning(s"Warning: Schema mismatch: one: ${left.schema()}, other: ${right.schema()}")
case (left: GlutenMetadata, right: GlutenMetadata) if left == right =>
case (left: GlutenMetadata, right: GlutenMetadata) =>
implicitly[Verifier[Schema]].verify(left.schema(), right.schema())
implicitly[Verifier[LogicalLink]].verify(left.logicalLink(), right.logicalLink())
case _ => throw new IllegalStateException(s"Metadata mismatch: one: $one, other $other")
}
}

trait Verifier[T <: Any] {
def verify(one: T, other: T): Unit
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.planner.metadata

import org.apache.gluten.planner.metadata.GlutenMetadataModel.Verifier

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}

case class LogicalLink(plan: LogicalPlan) {
override def hashCode(): Int = System.identityHashCode(plan)
override def equals(obj: Any): Boolean = obj match {
// LogicalLink's comparison is based on ref equality of the logical plans being compared.
case LogicalLink(otherPlan) => plan eq otherPlan
case _ => false
}

override def toString: String = s"${plan.nodeName}[${plan.stats.simpleString}]"
}

object LogicalLink {
private case class LogicalLinkNotFound() extends logical.LeafNode {
override def output: Seq[Attribute] = List.empty
override def canEqual(that: Any): Boolean = throw new UnsupportedOperationException()
override def computeStats(): Statistics = Statistics(sizeInBytes = 0)
}

val notFound = new LogicalLink(LogicalLinkNotFound())
implicit val verifier: Verifier[LogicalLink] = new Verifier[LogicalLink] with Logging {
override def verify(one: LogicalLink, other: LogicalLink): Unit = {
// LogicalLink's comparison is based on ref equality of the logical plans being compared.
if (one != other) {
logWarning(s"Warning: Logical link mismatch: one: $one, other: $other")
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.gluten.planner.metadata

import org.apache.gluten.planner.metadata.GlutenMetadataModel.Verifier

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.Attribute

case class Schema(output: Seq[Attribute]) {
private val hash = output.map(_.semanticHash()).hashCode()

override def hashCode(): Int = {
hash
}

override def equals(obj: Any): Boolean = obj match {
case other: Schema =>
semanticEquals(other)
case _ =>
false
}

private def semanticEquals(other: Schema): Boolean = {
if (output.size != other.output.size) {
return false
}
output.zip(other.output).forall {
case (left, right) =>
left.semanticEquals(right)
}
}

override def toString: String = {
output.toString()
}
}

object Schema {
implicit val verifier: Verifier[Schema] = new Verifier[Schema] with Logging {
override def verify(one: Schema, other: Schema): Unit = {
if (one != other) {
// We apply loose restriction on schema. Since Gluten still have some customized
// logics causing schema of an operator to change after being transformed.
// For example: https://github.com/apache/incubator-gluten/pull/5171
logWarning(s"Warning: Schema mismatch: one: $one, other: $other")
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ case class ConvEnforcerRule(reqConv: Conv) extends RasRule[SparkPlan] {
}
val transition = Conv.findTransition(conv, reqConv)
val after = transition.apply(node)
after.copyTagsFrom(node)
List(after)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,23 @@ private object GroupBasedBestFinder {
return Some(KnownCostPath(ras, path))
}
val childrenGroups = can.getChildrenGroups(allGroups).map(gn => allGroups(gn.groupId()))
val maybeBestChildrenPaths: Seq[Option[RasPath[T]]] = childrenGroups.map {
childGroup => childrenGroupsOutput(childGroup).map(kcg => kcg.best().rasPath)
val maybeBestChildrenPaths: Seq[Option[KnownCostPath[T]]] = childrenGroups.map {
childGroup => childrenGroupsOutput(childGroup).map(kcg => kcg.best())
}
if (maybeBestChildrenPaths.exists(_.isEmpty)) {
// Node should only be solved when all children outputs exist.
return None
}
val bestChildrenPaths = maybeBestChildrenPaths.map(_.get)
Some(KnownCostPath(ras, path.RasPath(ras, can, bestChildrenPaths).get))
val kcp = KnownCostPath(ras, path.RasPath(ras, can, bestChildrenPaths.map(_.rasPath)).get)
// Cost should be in monotonically increasing basis.
bestChildrenPaths.map(_.cost).foreach {
childCost =>
assert(
ras.costModel.costComparator().gteq(kcp.cost, childCost),
"Illegal decreasing cost")
}
Some(kcp)
}

override def solveGroup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ class OperationSuite extends AnyFunSuite {
48,
Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Leaf(30))))))))))))
assert(costModel.costOfCount == 32) // TODO reduce this for performance
assert(costModel.costCompareCount == 20) // TODO reduce this for performance
assert(costModel.costCompareCount == 50) // TODO reduce this for performance
}

test("Cost evaluation count - max cost") {
Expand Down Expand Up @@ -292,7 +292,7 @@ class OperationSuite extends AnyFunSuite {
48,
Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Unary3(48, Leaf(30))))))))))))
assert(costModel.costOfCount == 32) // TODO reduce this for performance
assert(costModel.costCompareCount == 20) // TODO reduce this for performance
assert(costModel.costCompareCount == 50) // TODO reduce this for performance
}
}

Expand Down

0 comments on commit 4ed161b

Please sign in to comment.