[CORE] Update Substrait to 0.24.0 (apache#4361)
Update Substrait to 0.24.0

Co-authored-by: Ted Jenks <[email protected]>
ted-jenks and Ted Jenks authored Jan 15, 2024
1 parent 6a5c64c · commit 3afb4be
Showing 14 changed files with 60 additions and 39 deletions.
pom.xml (2 changes: 1 addition & 1 deletion)
@@ -68,7 +68,7 @@
<!-- Package name to use when relocating shaded classes -->
<gluten.shade.packageName>io.glutenproject.shaded</gluten.shade.packageName>

-<substrait.version>0.5.0</substrait.version>
+<substrait.version>0.24.0</substrait.version>
<!-- The original protobuf version -->
<protobuf.version>3.23.4</protobuf.version>
<!-- The custom protobuf version (based on 3.23.4) with recursion limit enlarged. -->
----------------------------------------
@@ -21,7 +21,7 @@ import io.substrait.spark.DefaultExpressionVisitor
import org.apache.spark.sql.catalyst.util.DateTimeUtils

import io.substrait.expression.{Expression, FieldReference}
-import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, StrLiteral}
+import io.substrait.expression.Expression.{DateLiteral, DecimalLiteral, I32Literal, StrLiteral}
import io.substrait.function.ToTypeString
import io.substrait.util.DecimalUtil

@@ -38,9 +38,15 @@ class ExpressionToString extends DefaultExpressionVisitor[String] {
override def visit(expr: StrLiteral): String = {
expr.value()
}

+override def visit(expr: I32Literal): String = {
+expr.value().toString
+}

override def visit(expr: DateLiteral): String = {
DateTimeUtils.toJavaDate(expr.value()).toString
}

override def visit(expr: FieldReference): String = {
withFieldReference(expr)(i => "$" + i.toString)
}
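For context, a minimal sketch of how the newly handled I32Literal case behaves; the sample value is hypothetical, and the accept call assumes substrait-java's standard expression-visitor API:

```scala
import io.substrait.expression.ExpressionCreator

// Sketch only: ExpressionCreator.i32(nullable, value) builds an I32Literal,
// which the updated visitor now renders as its decimal string instead of
// falling back to the default.
val printer = new ExpressionToString
val rendered = ExpressionCreator.i32(false, 42).accept(printer) // expected: "42"
```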
----------------------------------------
@@ -94,10 +94,10 @@ class RelToVerboseString(addSuffix: Boolean) extends DefaultRelVisitor[String] {
builder.append(", ")
builder.append("filter=").append(filter)
})
-read.getGeneralExtension.ifPresent(
-generalExtension => {
+read.getCommonExtension.ifPresent(
+commonExtension => {
builder.append(", ")
-builder.append("generalExtension=").append(generalExtension)
+builder.append("commonExtension=").append(commonExtension)
})
}
override def visit(namedScan: NamedScan): String = {
----------------------------------------
@@ -18,7 +18,7 @@ package io.substrait.spark

import io.substrait.`type`.Type
import io.substrait.expression._
-import io.substrait.function.SimpleExtension
+import io.substrait.extension.SimpleExtension

class DefaultExpressionVisitor[T]
extends AbstractExpressionVisitor[T, RuntimeException]
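The same one-line move repeats across the following files: substrait-java 0.24.0 relocated SimpleExtension from io.substrait.function to io.substrait.extension. A minimal sketch of the relocated import in use; treating loadDefaults() as the bundled-extension loader is an assumption about the 0.24.0 API, not something this diff shows:

```scala
import io.substrait.extension.SimpleExtension

// Load substrait-java's bundled simple-extension catalog (core function
// definitions); assumed signature, hedged against 0.24.0.
val extensions: SimpleExtension.ExtensionCollection = SimpleExtension.loadDefaults()
```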
----------------------------------------
@@ -18,7 +18,7 @@ package io.substrait.spark

import io.substrait.spark.expression.ToAggregateFunction

-import io.substrait.function.SimpleExtension
+import io.substrait.extension.SimpleExtension

import java.util.Collections

----------------------------------------
@@ -27,7 +27,8 @@ import org.apache.spark.substrait.ToSubstraitType
import com.google.common.collect.{ArrayListMultimap, Multimap}
import io.substrait.`type`.Type
import io.substrait.expression.{Expression => SExpression, ExpressionCreator, FunctionArg}
-import io.substrait.function.{ParameterizedType, SimpleExtension, ToTypeString}
+import io.substrait.extension.SimpleExtension
+import io.substrait.function.{ParameterizedType, ToTypeString}
import io.substrait.utils.Util

import java.{util => ju}
----------------------------------------
@@ -78,6 +78,9 @@ class IgnoreNullableAndParameters(val typeToMatch: ParameterizedType)
override def visit(`type`: Type.Map): Boolean =
typeToMatch.isInstanceOf[Type.Map] || typeToMatch.isInstanceOf[ParameterizedType.Map]

+override def visit(`type`: Type.UserDefined): Boolean =
+typeToMatch.isInstanceOf[Type.UserDefined]

@throws[RuntimeException]
override def visit(expr: ParameterizedType.FixedChar): Boolean =
typeToMatch.isInstanceOf[Type.FixedChar] || typeToMatch
----------------------------------------
@@ -21,8 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._

import io.substrait.`type`.Type
import io.substrait.expression.{AggregateFunctionInvocation, Expression => SExpression, ExpressionCreator, FunctionArg}
-import io.substrait.function.SimpleExtension
-import io.substrait.proto.AggregateFunction
+import io.substrait.extension.SimpleExtension

import java.util.Collections

@@ -81,14 +80,14 @@ object ToAggregateFunction {
case SExpression.AggregationPhase.INTERMEDIATE_TO_RESULT => Final
case SExpression.AggregationPhase.INITIAL_TO_RESULT => Complete
}
-def fromSpark(isDistinct: Boolean): AggregateFunction.AggregationInvocation = if (isDistinct) {
-AggregateFunction.AggregationInvocation.AGGREGATION_INVOCATION_DISTINCT
+def fromSpark(isDistinct: Boolean): SExpression.AggregationInvocation = if (isDistinct) {
+SExpression.AggregationInvocation.DISTINCT
} else {
-AggregateFunction.AggregationInvocation.AGGREGATION_INVOCATION_ALL
+SExpression.AggregationInvocation.ALL
}

-def toSpark(innovation: AggregateFunction.AggregationInvocation): Boolean = innovation match {
-case AggregateFunction.AggregationInvocation.AGGREGATION_INVOCATION_DISTINCT => true
+def toSpark(innovation: SExpression.AggregationInvocation): Boolean = innovation match {
+case SExpression.AggregationInvocation.DISTINCT => true
case _ => false
}

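In 0.24.0 the invocation marker moves from the generated proto enum (AggregateFunction.AggregationInvocation with its AGGREGATION_INVOCATION_* constants) to the POJO expression model, which is why the proto import drops out above. A usage sketch of the two helpers, with hypothetical assertions:

```scala
import io.substrait.expression.{Expression => SExpression}

// Round-trip Spark's isDistinct flag through the new enum.
val invocation = ToAggregateFunction.fromSpark(isDistinct = true)
assert(invocation == SExpression.AggregationInvocation.DISTINCT)
assert(ToAggregateFunction.toSpark(invocation))
assert(!ToAggregateFunction.toSpark(SExpression.AggregationInvocation.ALL))
```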
----------------------------------------
@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression

import io.substrait.`type`.Type
import io.substrait.expression.{Expression => SExpression, FunctionArg}
-import io.substrait.function.SimpleExtension
+import io.substrait.extension.SimpleExtension

import scala.collection.JavaConverters

----------------------------------------
@@ -31,10 +31,12 @@ import org.apache.spark.sql.types.StructType
import org.apache.spark.substrait.ToSubstraitType
import org.apache.spark.substrait.ToSubstraitType.toNamedStruct

-import io.substrait.`type`.Type
+import io.substrait.{proto, relation}
+import io.substrait.debug.TreePrinter
import io.substrait.expression.{Expression => SExpression, ExpressionCreator}
+import io.substrait.extension.ExtensionCollector
import io.substrait.plan.{ImmutablePlan, ImmutableRoot, Plan}
-import io.substrait.relation
+import io.substrait.relation.RelProtoConverter

import java.util.Collections

@@ -210,11 +212,6 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
relation.Cross.builder
.left(left)
.right(right)
-.deriveRecordType(
-Type.Struct.builder
-.from(left.getRecordType)
-.from(right.getRecordType)
-.build)
.build
} else {
relation.Join.builder
@@ -251,11 +248,6 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
relation.Sort.builder.addAllSortFields(fields).input(input).build
}

-override def visitOffset(plan: Offset): relation.Rel = {
-throw new UnsupportedOperationException(
-s"Unable to convert the plan to a substrait plan: $plan")
-}

private def toExpression(output: Seq[Attribute])(e: Expression): SExpression = {
toSubstraitExp(e, output)
}
@@ -327,6 +319,27 @@ class ToSubstraitRel extends AbstractLogicalPlanVisitor with Logging {
))
.build()
}

+def tree(p: LogicalPlan): String = {
+TreePrinter.tree(visit(p))
+}
+
+def toProtoSubstrait(p: LogicalPlan): Array[Byte] = {
+val substraitRel = visit(p)
+
+val extensionCollector = new ExtensionCollector
+val relProtoConverter = new RelProtoConverter(extensionCollector)
+val builder = proto.Plan
+.newBuilder()
+.addRelations(
+proto.PlanRel
+.newBuilder()
+.setRel(substraitRel
+.accept(relProtoConverter))
+)
+extensionCollector.addExtensionsToPlan(builder)
+builder.build().toByteArray
+}
}

private[logical] class WithLogicalSubQuery(toSubstraitRel: ToSubstraitRel)
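A hedged usage sketch of the new toProtoSubstrait entry point added above; the SparkSession `spark` and the query text are placeholders, and parseFrom is the standard protobuf deserializer:

```scala
// Sketch: serialize an optimized Spark logical plan to Substrait protobuf
// bytes, then parse them back into the generated proto.Plan message.
val logical = spark.sql("SELECT 1").queryExecution.optimizedPlan
val bytes = new ToSubstraitRel().toProtoSubstrait(logical)
val roundTripped = io.substrait.proto.Plan.parseFrom(bytes)
```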
----------------------------------------
@@ -69,6 +69,4 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] {
override def visitSort(sort: Sort): Rel = t(sort)

override def visitWithCTE(p: WithCTE): Rel = t(p)
-
-def visitOffset(p: Offset): Rel = t(p)
}
----------------------------------------
@@ -70,7 +70,5 @@ class AbstractLogicalPlanVisitor extends LogicalPlanVisitor[relation.Rel] {

override def visitWithCTE(p: WithCTE): Rel = t(p)

-def visitOffset(p: Offset): Rel = t(p)
-
override def visitRebalancePartitions(p: RebalancePartitions): Rel = t(p)
}
----------------------------------------
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.util.resourceToString
import org.apache.spark.sql.test.SharedSparkSession

import io.substrait.debug.TreePrinter
-import io.substrait.expression.proto.FunctionCollector
+import io.substrait.extension.ExtensionCollector
import io.substrait.plan.{Plan, PlanProtoConverter, ProtoPlanConverter}
import io.substrait.proto
import io.substrait.relation.RelProtoConverter
@@ -57,8 +57,8 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
val logicalPlan = plan(sql)
val substraitRel = convert.visit(logicalPlan)

-val functionCollector = new FunctionCollector
-val relProtoConverter = new RelProtoConverter(functionCollector)
+val extensionCollector = new ExtensionCollector
+val relProtoConverter = new RelProtoConverter(extensionCollector)
val builder = proto.Plan
.newBuilder()
.addRelations(
@@ -71,7 +71,7 @@ trait SubstraitPlanTestBase { self: SharedSparkSession =>
.accept(relProtoConverter))
)
)
-functionCollector.addFunctionsToPlan(builder)
+extensionCollector.addExtensionsToPlan(builder)
builder.build()
}

----------------------------------------
@@ -20,7 +20,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{IntegerType, LongType}

-import io.substrait.`type`.Type
+import io.substrait.`type`.TypeCreator
import io.substrait.expression.{Expression => SExpression, ExpressionCreator}

class ArithmeticExpressionSuite extends SparkFunSuite with SubstraitExpressionTestBase {
@@ -31,8 +31,11 @@ class ArithmeticExpressionSuite extends SparkFunSuite with SubstraitExpressionTestBase {
Add(Literal(1), Literal(2L)),
func => {
assertResult(true)(func.arguments().get(1).isInstanceOf[SExpression.I64Literal])
-assertResult(ExpressionCreator.cast(Type.REQUIRED.I64, ExpressionCreator.i32(false, 1)))(
-func.arguments().get(0))
+assertResult(
+ExpressionCreator.cast(
+TypeCreator.REQUIRED.I64,
+ExpressionCreator.i32(false, 1)
+))(func.arguments().get(0))
},
bidirectional = false
) // TODO: implicit calcite cast
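The test update reflects one more 0.24.0 rename: the REQUIRED/NULLABLE type factories now live on TypeCreator instead of Type. A minimal sketch mirroring the assertion above:

```scala
import io.substrait.`type`.TypeCreator
import io.substrait.expression.ExpressionCreator

// Build a non-nullable i64 type and cast an i32 literal to it,
// exactly as the updated test expects.
val i64 = TypeCreator.REQUIRED.I64
val cast = ExpressionCreator.cast(i64, ExpressionCreator.i32(false, 1))
```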
