[CORE] Fix exception of pb MessageToJsonString (apache#3823)
exmy authored Nov 28, 2023
1 parent a55ab51 commit 8518680
Showing 7 changed files with 33 additions and 28 deletions.
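
In short: call sites used to build the protobuf Any by hand, with a hard-coded type URL of "/google.protobuf.StringValue"; this commit routes that packing through a new TransformerApi.getPackMessage hook so each backend can choose the type-URL prefix its native side expects when converting the plan to JSON (the MessageToJsonString call named in the title). A minimal sketch of the two prefixes now in play, assuming a StringValue payload like the ones packed in the hunks below:

import com.google.protobuf.{Any, StringValue}

object PackSketch {
  val message: StringValue = StringValue.newBuilder().setValue("JoinParameters:").build()

  // ClickHouse backend (CHTransformerApi): default prefix.
  val chPacked: Any = Any.pack(message)
  // chPacked.getTypeUrl == "type.googleapis.com/google.protobuf.StringValue"

  // Velox backend (TransformerApiImpl): empty prefix, matching the old hand-built URL.
  val veloxPacked: Any = Any.pack(message, "")
  // veloxPacked.getTypeUrl == "/google.protobuf.StringValue"
}

The non-standard "/google.protobuf.StringValue" URL is presumably what made MessageToJsonString throw on the ClickHouse side; packing with the default prefix restores the canonical "type.googleapis.com/..." URL there, while Velox keeps its previous behaviour.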
@@ -37,6 +37,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet

import com.google.common.collect.Lists
import com.google.protobuf.{Any, Message}

import java.util

@@ -267,4 +268,6 @@ class CHTransformerApi extends TransformerApi with Logging {
override def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String = {
throw new UnsupportedOperationException("CH backend does not support this method")
}

override def getPackMessage(message: Message): Any = Any.pack(message)
}
@@ -33,6 +33,8 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDi
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet

import com.google.protobuf.{Any, Message}

import java.util.{Map => JMap}

class TransformerApiImpl extends TransformerApi with Logging {
@@ -135,4 +137,6 @@ class TransformerApiImpl extends TransformerApi with Logging {
tmpRuntime.release()
}
}

override def getPackMessage(message: Message): Any = Any.pack(message, "")
}
@@ -27,6 +27,8 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDi
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.util.collection.BitSet

import com.google.protobuf.{Any, Message}

import java.util

trait TransformerApi {
@@ -85,4 +87,6 @@ trait TransformerApi {
nullOnOverflow: Boolean): ExpressionNode

def getNativePlanString(substraitPlan: Array[Byte], details: Boolean): String

def getPackMessage(message: Message): Any
}
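
Call sites reach this hook through BackendsApiManager, exactly as the hunks below do. A hedged sketch of the common pattern (packParameters is a hypothetical helper name, not something added by this commit):

import com.google.protobuf.{Any, StringValue}
import io.glutenproject.backendsapi.BackendsApiManager

object PackParametersSketch {
  // Wraps a parameters string in a StringValue and lets the active backend pack it:
  // Any.pack(message) on ClickHouse, Any.pack(message, "") on Velox.
  def packParameters(text: String): Any = {
    val message = StringValue.newBuilder().setValue(text).build()
    BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
  }
}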
@@ -204,7 +204,7 @@ trait HashJoinLikeExecTransformer
substraitJoinType,
needSwitchChildren,
joinType,
genJoinParametersBuilder(),
genJoinParameters(),
null,
null,
streamedPlan.output,
@@ -278,7 +278,7 @@ trait HashJoinLikeExecTransformer
substraitJoinType,
needSwitchChildren,
joinType,
genJoinParametersBuilder(),
genJoinParameters(),
inputStreamedRelNode,
inputBuildRelNode,
inputStreamedOutput,
inputBuildOutput)
}

def genJoinParametersBuilder(): Any.Builder = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId) = genJoinParameters()
def genJoinParameters(): Any = {
val (isBHJ, isNullAwareAntiJoin, buildHashTableId) = genJoinParametersInternal()
// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
// isBHJ: 0 for SHJ, 1 for BHJ
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
Any.newBuilder
.setValue(message.toByteString)
.setTypeUrl("/google.protobuf.StringValue")
BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
}

def genJoinParameters(): (Int, Int, String) = {
def genJoinParametersInternal(): (Int, Int, String) = {
(0, 0, "")
}
}
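
For the Velox path nothing changes on the wire: the builder removed above and Any.pack with an empty prefix produce the same message. A small check sketch (standalone, not part of the commit):

import com.google.protobuf.{Any, StringValue}

object PackEquivalenceSketch extends App {
  val message = StringValue.newBuilder().setValue("JoinParameters:").build()

  // What genJoinParametersBuilder() used to emit (after the caller's .build()).
  val manual = Any.newBuilder
    .setValue(message.toByteString)
    .setTypeUrl("/google.protobuf.StringValue")
    .build()

  // What TransformerApiImpl.getPackMessage now returns.
  val packed = Any.pack(message, "")

  assert(manual == packed) // same value bytes, same "/google.protobuf.StringValue" URL
}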
@@ -406,7 +404,7 @@ abstract class BroadcastHashJoinExecTransformer(
// Unique ID for the built hash table
lazy val buildHashTableId: String = "BuiltHashTable-" + buildPlan.id

override def genJoinParameters(): (Int, Int, String) = {
override def genJoinParametersInternal(): (Int, Int, String) = {
(1, if (isNullAwareAntiJoin) 1 else 0, buildHashTableId)
}

@@ -16,6 +16,7 @@
*/
package io.glutenproject.execution

import io.glutenproject.backendsapi.BackendsApiManager
import io.glutenproject.expression.{AttributeReferenceTransformer, ConverterUtils, ExpressionConverter}
import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.SubstraitContext
@@ -40,9 +41,8 @@ object JoinUtils {
// Normally the enhancement node is only used for plan validation. But here the enhancement
// is also used in the execution phase. In this case an empty typeUrlPrefix needs to be passed,
// so that it can be correctly parsed into a JSON string on the cpp side.
Any.pack(
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf,
/* typeUrlPrefix */ "")
BackendsApiManager.getTransformerApiInstance.getPackMessage(
TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
}
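
Because the Velox-side getPackMessage keeps the empty prefix, the enhancement packed above still ends up with a type URL of just "/" followed by the message's full proto name (the concrete Substrait type returned by TypeBuilder.makeStruct(...).toProtobuf is not shown in this diff). The general rule, for reference:

import com.google.protobuf.{Any, Message}

object TypeUrlRuleSketch {
  // For any message m, Any.pack(m, prefix).getTypeUrl is prefix + "/" + the message's
  // full proto name (unless the prefix already ends with "/"), so an empty prefix
  // yields "/<full name>" and the default overload yields "type.googleapis.com/<full name>".
  def packedTypeUrl(message: Message, typeUrlPrefix: String): String =
    Any.pack(message, typeUrlPrefix).getTypeUrl
}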

def createExtensionNode(output: Seq[Attribute], validation: Boolean): AdvancedExtensionNode = {
@@ -132,12 +132,12 @@
}

def createJoinExtensionNode(
joinParameters: Any.Builder,
joinParameters: Any,
output: Seq[Attribute]): AdvancedExtensionNode = {
// Use field [optimization] in an extension node
// to send some join parameters through the Substrait plan.
val enhancement = createEnhancement(output)
ExtensionBuilder.makeAdvancedExtension(joinParameters.build(), enhancement)
ExtensionBuilder.makeAdvancedExtension(joinParameters, enhancement)
}

// Return the direct join output.
@@ -180,7 +180,7 @@
substraitJoinType: JoinRel.JoinType,
exchangeTable: Boolean,
joinType: JoinType,
joinParameters: Any.Builder,
joinParameters: Any,
inputStreamedRelNode: RelNode,
inputBuildRelNode: RelNode,
inputStreamedOutput: Seq[Attribute],
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.vectorized.ColumnarBatch

import com.google.protobuf.StringValue
import com.google.protobuf.{Any, StringValue}
import io.substrait.proto.JoinRel

import scala.collection.JavaConverters._
val (bufferedKeys, streamedKeys, bufferedPlan, streamedPlan) =
(rightKeys, leftKeys, right, left)

override def stringArgs: Iterator[Any] = super.stringArgs.toSeq.dropRight(1).iterator
override def stringArgs: Iterator[scala.Any] = super.stringArgs.toSeq.dropRight(1).iterator

override def simpleStringWithNodeId(): String = {
val opId = ExplainUtils.getOpId(this)
@@ -176,7 +176,7 @@
override def metricsUpdater(): MetricsUpdater =
BackendsApiManager.getMetricsApiInstance.genSortMergeJoinTransformerMetricsUpdater(metrics)

def genJoinParametersBuilder(): com.google.protobuf.Any.Builder = {
def genJoinParameters(): Any = {
val (isSMJ, isNullAwareAntiJoin) = (1, 0)
// Start with "JoinParameters:"
val joinParametersStr = new StringBuffer("JoinParameters:")
.newBuilder()
.setValue(joinParametersStr.toString)
.build()
com.google.protobuf.Any.newBuilder
.setValue(message.toByteString)
.setTypeUrl("/google.protobuf.StringValue")
BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
}

// Direct output order of substrait join operation
@@ -235,7 +233,7 @@
substraitJoinType,
false,
joinType,
genJoinParametersBuilder(),
genJoinParameters(),
null,
null,
streamedPlan.output,
@@ -298,7 +296,7 @@
substraitJoinType,
false,
joinType,
genJoinParametersBuilder(),
genJoinParameters(),
inputStreamedRelNode,
inputBuildRelNode,
inputStreamedOutput,
@@ -84,7 +84,7 @@ case class WindowExecTransformer(

override def outputPartitioning: Partitioning = child.outputPartitioning

def genWindowParametersBuilder(): com.google.protobuf.Any.Builder = {
def genWindowParameters(): Any = {
// Start with "WindowParameters:"
val windowParametersStr = new StringBuffer("WindowParameters:")
// isStreaming: 1 for streaming, 0 for sort
.newBuilder()
.setValue(windowParametersStr.toString)
.build()
com.google.protobuf.Any.newBuilder
.setValue(message.toByteString)
.setTypeUrl("/google.protobuf.StringValue")
BackendsApiManager.getTransformerApiInstance.getPackMessage(message)
}
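
The window path mirrors the join path: the parameters string is wrapped in a StringValue and handed to the backend to pack. A hedged sketch of that flow; the key/separator layout inside the collapsed lines is not visible here and is assumed, and packMessage stands in for the BackendsApiManager call:

import com.google.protobuf.{Any, Message, StringValue}

object WindowParametersSketch {
  def genWindowParameters(isStreaming: Boolean, packMessage: Message => Any): Any = {
    // Start with "WindowParameters:"; isStreaming: 1 for streaming, 0 for sort.
    val windowParametersStr = new StringBuffer("WindowParameters:")
    windowParametersStr.append("isStreaming=").append(if (isStreaming) 1 else 0).append("\n") // layout assumed
    val message = StringValue.newBuilder().setValue(windowParametersStr.toString).build()
    packMessage(message)
  }
}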

def getRelNode(
@@ -157,7 +155,7 @@
}.asJava
if (!validation) {
val extensionNode =
ExtensionBuilder.makeAdvancedExtension(genWindowParametersBuilder.build(), null)
ExtensionBuilder.makeAdvancedExtension(genWindowParameters(), null)
RelBuilder.makeWindowRel(
input,
windowExpressions,
