
[GLUTEN-4105][VL] Fix parse substrait plan to json plan error when validation (#4107)

Use BackendsApiManager.getTransformerApiInstance.packPBMessage instead of Any.pack directly.
Yohahaha authored Dec 23, 2023
1 parent 7fff4e5 commit 1fd405e
Showing 17 changed files with 51 additions and 43 deletions.
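The change hinges on how com.google.protobuf.Any builds its type URL. Below is a minimal standalone sketch, not part of this commit (PackDemo is a hypothetical object; the two pack overloads and the resulting URLs follow protobuf-java's documented behavior), of what each backend's packPBMessage now delegates to:

```scala
import com.google.protobuf.{Any, StringValue}

// Sketch: the two Any.pack overloads the backends choose between.
object PackDemo extends App {
  val msg = StringValue.newBuilder.setValue("isStreaming=1").build

  // Default prefix, which the CH backend keeps using:
  println(Any.pack(msg).getTypeUrl)
  // -> type.googleapis.com/google.protobuf.StringValue

  // Empty prefix, which the Velox TransformerApiImpl passes so the C++ side
  // can parse the substrait plan (and its packed extensions) into JSON:
  println(Any.pack(msg, "").getTypeUrl)
  // -> /google.protobuf.StringValue
}
```

Routing every call site through BackendsApiManager.getTransformerApiInstance.packPBMessage keeps that prefix decision in one place per backend instead of scattering raw Any.pack calls across the transformers.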
@@ -26,7 +26,6 @@
import io.glutenproject.substrait.plan.PlanBuilder;
import io.glutenproject.substrait.plan.PlanNode;

-import com.google.protobuf.Any;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.internal.SQLConf;

@@ -72,7 +71,9 @@ public boolean doValidate(byte[] subPlan) {
private PlanNode buildNativeConfNode(Map<String, String> confs) {
StringMapNode stringMapNode = ExpressionBuilder.makeStringMap(confs);
AdvancedExtensionNode extensionNode =
-ExtensionBuilder.makeAdvancedExtension(Any.pack(stringMapNode.toProtobuf()));
+ExtensionBuilder.makeAdvancedExtension(
+  BackendsApiManager.getTransformerApiInstance()
+    .packPBMessage(stringMapNode.toProtobuf()));
return PlanBuilder.makePlan(extensionNode);
}

@@ -234,5 +234,5 @@ class CHTransformerApi extends TransformerApi with Logging {
throw new UnsupportedOperationException("CH backend does not support this method")
}

-override def getPackMessage(message: Message): Any = Any.pack(message)
+override def packPBMessage(message: Message): Any = Any.pack(message)
}
@@ -96,5 +96,5 @@ class TransformerApiImpl extends TransformerApi with Logging {
}
}

-override def getPackMessage(message: Message): Any = Any.pack(message, "")
+override def packPBMessage(message: Message): Any = Any.pack(message, "")
}
@@ -16,6 +16,7 @@
*/
package io.glutenproject.execution

+import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression._
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode}
@@ -31,8 +32,6 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

-import com.google.protobuf.Any
-
import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}

@@ -162,7 +161,8 @@
groupingExpressions.size + aggregateExpressions.size)
} else {
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, getPartialAggOutTypes).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, getPartialAggOutTypes).toProtobuf))
RelBuilder.makeProjectRel(
aggRel,
expressionNodes,
@@ -435,7 +435,8 @@ case class HashAggregateExecTransformer(
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
inputRel,
exprNodes,
@@ -16,7 +16,8 @@
*/
package io.glutenproject.substrait.rel;

-import com.google.protobuf.Any;
+import io.glutenproject.backendsapi.BackendsApiManager;
+
import com.google.protobuf.StringValue;
import io.substrait.proto.ReadRel;

@@ -69,7 +70,8 @@ public ReadRel.ExtensionTable toProtobuf() {
ReadRel.ExtensionTable.Builder extensionTableBuilder = ReadRel.ExtensionTable.newBuilder();
StringValue extensionTable =
StringValue.newBuilder().setValue(extensionTableStr.toString()).build();
-extensionTableBuilder.setDetail(Any.pack(extensionTable));
+extensionTableBuilder.setDetail(
+  BackendsApiManager.getTransformerApiInstance().packPBMessage(extensionTable));
return extensionTableBuilder.build();
}
}
@@ -74,5 +74,5 @@ trait TransformerApi {

def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String

-def getPackMessage(message: Message): Any
+def packPBMessage(message: Message): Any
}
@@ -33,8 +33,6 @@ import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV
import org.apache.spark.sql.utils.StructTypeFWD
import org.apache.spark.sql.vectorized.ColumnarBatch

-import com.google.protobuf.Any
-
import scala.collection.JavaConverters._

abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkPlan)
@@ -88,7 +86,8 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId)
}
}
@@ -221,7 +220,8 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
projExprNodeList,
@@ -30,8 +30,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.execution._

-import com.google.protobuf.Any
-
import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._
@@ -115,7 +113,8 @@ case class ExpandExecTransformer(
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
preExprNodes,
@@ -167,7 +166,8 @@
}

val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeExpandRel(input, projectSetExprNodes, extensionNode, context, operatorId)
}
}
@@ -30,8 +30,6 @@ import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan

-import com.google.protobuf.Any
-
import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._
@@ -164,7 +162,8 @@ case class GenerateExecTransformer(
val inputTypeNodeList =
inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeGenerateRel(input, generator, childOutput, extensionNode, context, operatorId)
}
}
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.aggregate._
import org.apache.spark.sql.types._

-import com.google.protobuf.{Any, StringValue}
+import com.google.protobuf.StringValue

import java.util.{ArrayList => JArrayList, List => JList}

@@ -266,7 +266,8 @@ abstract class HashAggregateExecBaseTransformer(
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
preExprNodes,
@@ -379,7 +380,8 @@
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
aggRel,
resExprNodes,
@@ -557,7 +559,8 @@ abstract class HashAggregateExecBaseTransformer(
val inputTypeNodeList = originalInputAttributes
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)
} else {
null
}
@@ -568,7 +571,7 @@
"0"
}
val optimization =
-BackendsApiManager.getTransformerApiInstance.getPackMessage(
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
StringValue.newBuilder.setValue(s"isStreaming=$isStreaming\n").build)
ExtensionBuilder.makeAdvancedExtension(optimization, enhancement)
}
@@ -292,7 +292,7 @@ trait HashJoinLikeExecTransformer
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
-BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
+BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

def genJoinParametersInternal(): (Int, Int, String) = {
@@ -41,7 +41,7 @@ object JoinUtils {
// Normally the enhancement node is only used for plan validation. But here the enhancement
// is also used in execution phase. In this case an empty typeUrlPrefix need to be passed,
// so that it can be correctly parsed into json string on the cpp side.
-BackendsApiManager.getTransformerApiInstance.getPackMessage(
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
}

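The comment in JoinUtils above notes the enhancement node is also consumed at execution time. Here is a small standalone sketch (again hypothetical, not from this commit) showing that packing with an empty typeUrlPrefix still round-trips on the JVM side, since Any.unpack matches only the message name after the last '/':

```scala
import com.google.protobuf.{Any, StringValue}

object RoundTripDemo extends App {
  // Pack with an empty prefix, as the Velox backend's packPBMessage does.
  val packed = Any.pack(StringValue.newBuilder.setValue("0").build, "")
  assert(packed.getTypeUrl == "/google.protobuf.StringValue")

  // unpack compares only the trailing type name, so the empty prefix is
  // harmless for JVM consumers; the C++-side JSON-parsing claim is the
  // commit's own rationale, not something this snippet exercises.
  assert(packed.unpack(classOf[StringValue]).getValue == "0")
}
```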
@@ -28,8 +28,6 @@ import io.glutenproject.substrait.rel.{RelBuilder, RelNode}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.SparkPlan

-import com.google.protobuf.Any
-
import scala.collection.JavaConverters._

case class LimitTransformer(child: SparkPlan, offset: Long, count: Long)
@@ -81,7 +79,8 @@ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long)
val inputTypeNodes =
inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf))
RelBuilder.makeFetchRel(input, offset, count, extensionNode, context, operatorId)
}
}
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._

-import com.google.protobuf.Any
import io.substrait.proto.SortField

import java.util.{ArrayList => JArrayList}
@@ -112,7 +111,8 @@ case class SortExecTransformer(
}

val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
input,
projectExpressions,
@@ -136,7 +136,8 @@

}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeSortRel(inputRel, sortFieldList, extensionNode, context, operatorId)
}
@@ -157,7 +158,8 @@ case class SortExecTransformer(
}

val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(
sortRel,
new JArrayList[ExpressionNode](selectOrigins),
@@ -195,7 +197,8 @@ case class SortExecTransformer(
val inputTypeNodeList = originalInputAttributes.map(
attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList.asJava).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList.asJava).toProtobuf))

RelBuilder.makeSortRel(input, sortFieldList.asJava, extensionNode, context, operatorId)
}
@@ -193,7 +193,7 @@ case class SortMergeJoinExecTransformer(
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
-BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
+BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

// Direct output order of substrait join operation
@@ -97,7 +97,7 @@ case class WindowExecTransformer(
.newBuilder()
.setValue(windowParametersStr.toString)
.build()
-BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
+BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
}

def getRelNode(
@@ -168,7 +168,8 @@
.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
.asJava
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))

RelBuilder.makeWindowRel(
input,
@@ -34,8 +34,6 @@ import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.python.EvalPythonExec
import org.apache.spark.sql.types.StructType

-import com.google.protobuf.Any
-
import java.util.{ArrayList => JArrayList, List => JList}

case class EvalPythonExecTransformer(
@@ -120,7 +118,8 @@ case class EvalPythonExecTransformer(
inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
}
val extensionNode = ExtensionBuilder.makeAdvancedExtension(
-Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
+BackendsApiManager.getTransformerApiInstance.packPBMessage(
+  TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf))
RelBuilder.makeProjectRel(input, expressionNodes, extensionNode, context, operatorId, -1)
}
}