diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
index 0dd110fa542f2..28bf1eeabd233 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/CartesianProductExecTransformer.scala
@@ -86,12 +86,10 @@ case class CartesianProductExecTransformer(
     val (inputRightRelNode, inputRightOutput) =
       (rightPlanContext.root, rightPlanContext.outputAttributes)
 
-    val expressionNode = condition.map {
-      expr =>
-        ExpressionConverter
-          .replaceWithExpressionTransformer(expr, inputLeftOutput ++ inputRightOutput)
-          .doTransform(context.registeredFunction)
-    }
+    val expressionNode =
+      condition.map {
+        SubstraitUtil.toSubstraitExpression(_, inputLeftOutput ++ inputRightOutput, context)
+      }
 
     val extensionNode =
       JoinUtils.createExtensionNode(inputLeftOutput ++ inputRightOutput, validation = false)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
index 9dd73800e29bc..12d08518509a7 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/JoinUtils.scala
@@ -16,13 +16,12 @@
  */
 package org.apache.gluten.execution
 
-import org.apache.gluten.backendsapi.BackendsApiManager
-import org.apache.gluten.expression.{AttributeReferenceTransformer, ConverterUtils, ExpressionConverter}
-import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.expression.{AttributeReferenceTransformer, ExpressionConverter}
 import org.apache.gluten.substrait.SubstraitContext
 import org.apache.gluten.substrait.expression.{ExpressionBuilder, ExpressionNode}
 import org.apache.gluten.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder}
 import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+import org.apache.gluten.utils.SubstraitUtil
 
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.plans._
@@ -34,21 +33,11 @@ import io.substrait.proto.{CrossRel, JoinRel}
 import scala.collection.JavaConverters._
 
 object JoinUtils {
-  private def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
-    val inputTypeNodes = output.map {
-      attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
-    }
-    // Normally the enhancement node is only used for plan validation. But here the enhancement
-    // is also used in execution phase. In this case an empty typeUrlPrefix need to be passed,
-    // so that it can be correctly parsed into json string on the cpp side.
-    BackendsApiManager.getTransformerApiInstance.packPBMessage(
-      TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
-  }
 
   def createExtensionNode(output: Seq[Attribute], validation: Boolean): AdvancedExtensionNode = {
     // Use field [enhancement] in a extension node for input type validation.
     if (validation) {
-      ExtensionBuilder.makeAdvancedExtension(createEnhancement(output))
+      ExtensionBuilder.makeAdvancedExtension(SubstraitUtil.createEnhancement(output))
     } else {
       null
     }
@@ -58,7 +47,7 @@ object JoinUtils {
     !keyExprs.forall(_.isInstanceOf[AttributeReference])
   }
 
-  def createPreProjectionIfNeeded(
+  private def createPreProjectionIfNeeded(
       keyExprs: Seq[Expression],
       inputNode: RelNode,
       inputNodeOutput: Seq[Attribute],
@@ -131,17 +120,17 @@
     }
   }
 
-  def createJoinExtensionNode(
+  private def createJoinExtensionNode(
       joinParameters: Any,
       output: Seq[Attribute]): AdvancedExtensionNode = {
     // Use field [optimization] in a extension node
     // to send some join parameters through Substrait plan.
-    val enhancement = createEnhancement(output)
+    val enhancement = SubstraitUtil.createEnhancement(output)
     ExtensionBuilder.makeAdvancedExtension(joinParameters, enhancement)
   }
 
   // Return the direct join output.
-  protected def getDirectJoinOutput(
+  private def getDirectJoinOutput(
       joinType: JoinType,
       leftOutput: Seq[Attribute],
       rightOutput: Seq[Attribute]): (Seq[Attribute], Seq[Attribute]) = {
@@ -164,7 +153,7 @@
     }
   }
 
-  protected def getDirectJoinOutputSeq(
+  private def getDirectJoinOutputSeq(
      joinType: JoinType,
      leftOutput: Seq[Attribute],
      rightOutput: Seq[Attribute]): Seq[Attribute] = {
@@ -209,8 +198,8 @@
       validation)
 
     // Combine join keys to make a single expression.
-    val joinExpressionNode = (streamedKeys
-      .zip(buildKeys))
+    val joinExpressionNode = streamedKeys
+      .zip(buildKeys)
       .map {
         case ((leftKey, leftType), (rightKey, rightType)) =>
           HashJoinLikeExecTransformer.makeEqualToExpression(
@@ -225,12 +214,10 @@
           HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext.registeredFunction))
 
     // Create post-join filter, which will be computed in hash join.
-    val postJoinFilter = condition.map {
-      expr =>
-        ExpressionConverter
-          .replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput)
-          .doTransform(substraitContext.registeredFunction)
-    }
+    val postJoinFilter =
+      condition.map {
+        SubstraitUtil.toSubstraitExpression(_, streamedOutput ++ buildOutput, substraitContext)
+      }
 
     // Create JoinRel.
     val joinRel = RelBuilder.makeJoinRel(
@@ -340,12 +327,14 @@
       joinParameters: Any,
       validation: Boolean = false
   ): RelNode = {
-    val expressionNode = condition.map {
-      expr =>
-        ExpressionConverter
-          .replaceWithExpressionTransformer(expr, inputStreamedOutput ++ inputBuildOutput)
-          .doTransform(substraitContext.registeredFunction)
-    }
+    val expressionNode =
+      condition.map {
+        SubstraitUtil.toSubstraitExpression(
+          _,
+          inputStreamedOutput ++ inputBuildOutput,
+          substraitContext)
+      }
+
     val extensionNode =
       createJoinExtensionNode(joinParameters, inputStreamedOutput ++ inputBuildOutput)
 
diff --git a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala
index d78f21beaabfe..d2ec994ba64b6 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/execution/WriteFilesExecTransformer.scala
@@ -21,10 +21,11 @@ import org.apache.gluten.backendsapi.BackendsApiManager
 import org.apache.gluten.expression.ConverterUtils
 import org.apache.gluten.extension.ValidationResult
 import org.apache.gluten.metrics.MetricsUpdater
-import org.apache.gluten.substrait.`type`.{ColumnTypeNode, TypeBuilder}
+import org.apache.gluten.substrait.`type`.ColumnTypeNode
 import org.apache.gluten.substrait.SubstraitContext
 import org.apache.gluten.substrait.extensions.ExtensionBuilder
 import org.apache.gluten.substrait.rel.{RelBuilder, RelNode}
+import org.apache.gluten.utils.SubstraitUtil
 
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
@@ -32,7 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.datasources.FileFormat
+import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.sources.DataSourceRegister
 import org.apache.spark.sql.types.MetadataBuilder
 
 import com.google.protobuf.{Any, StringValue}
@@ -40,7 +43,6 @@ import org.apache.parquet.hadoop.ParquetOutputFormat
 
 import java.util.Locale
 
-import scala.collection.JavaConverters._
 import scala.collection.convert.ImplicitConversions.`collection AsScalaIterable`
 
 /**
@@ -56,7 +58,7 @@ case class WriteFilesExecTransformer(
     staticPartitions: TablePartitionSpec)
   extends UnaryTransformSupport {
   // Note: "metrics" is made transient to avoid sending driver-side metrics to tasks.
-  @transient override lazy val metrics =
+  @transient override lazy val metrics: Map[String, SQLMetric] =
     BackendsApiManager.getMetricsApiInstance.genWriteFilesTransformerMetrics(sparkContext)
 
   override def metricsUpdater(): MetricsUpdater =
@@ -66,11 +68,18 @@ case class WriteFilesExecTransformer(
 
   private val caseInsensitiveOptions = CaseInsensitiveMap(options)
 
-  def genWriteParameters(): Any = {
+  private def genWriteParameters(): Any = {
+    val fileFormatStr = fileFormat match {
+      case register: DataSourceRegister =>
+        register.shortName
+      case _ => "UnknownFileFormat"
+    }
     val compressionCodec =
       WriteFilesExecTransformer.getCompressionCodec(caseInsensitiveOptions).capitalize
     val writeParametersStr = new StringBuffer("WriteParameters:")
-    writeParametersStr.append("is").append(compressionCodec).append("=1").append("\n")
+    writeParametersStr.append("is").append(compressionCodec).append("=1")
+    writeParametersStr.append(";format=").append(fileFormatStr).append("\n")
+
     val message = StringValue
       .newBuilder()
       .setValue(writeParametersStr.toString)
@@ -78,15 +87,6 @@ case class WriteFilesExecTransformer(
     BackendsApiManager.getTransformerApiInstance.packPBMessage(message)
   }
 
-  def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
-    val inputTypeNodes = output.map {
-      attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
-    }
-
-    BackendsApiManager.getTransformerApiInstance.packPBMessage(
-      TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
-  }
-
   def getRelNode(
       context: SubstraitContext,
       originalInputAttributes: Seq[Attribute],
@@ -118,10 +118,11 @@ case class WriteFilesExecTransformer(
     val extensionNode = if (!validation) {
       ExtensionBuilder.makeAdvancedExtension(
         genWriteParameters(),
-        createEnhancement(originalInputAttributes))
+        SubstraitUtil.createEnhancement(originalInputAttributes))
     } else {
       // Use a extension node to send the input types through Substrait plan for validation.
-      ExtensionBuilder.makeAdvancedExtension(createEnhancement(originalInputAttributes))
+      ExtensionBuilder.makeAdvancedExtension(
+        SubstraitUtil.createEnhancement(originalInputAttributes))
     }
     RelBuilder.makeWriteRel(
       input,
@@ -133,7 +134,7 @@ case class WriteFilesExecTransformer(
       operatorId)
   }
 
-  private def getFinalChildOutput(): Seq[Attribute] = {
+  private def getFinalChildOutput: Seq[Attribute] = {
     val metadataExclusionList = conf
       .getConf(GlutenConfig.NATIVE_WRITE_FILES_COLUMN_METADATA_EXCLUSION_LIST)
       .split(",")
@@ -143,7 +144,7 @@ case class WriteFilesExecTransformer(
   }
 
   override protected def doValidateInternal(): ValidationResult = {
-    val finalChildOutput = getFinalChildOutput()
+    val finalChildOutput = getFinalChildOutput
     val validationResult =
       BackendsApiManager.getSettings.supportWriteFilesExec(
         fileFormat,
@@ -165,7 +166,7 @@ case class WriteFilesExecTransformer(
     val childCtx = child.asInstanceOf[TransformSupport].transform(context)
     val operatorId = context.nextOperatorId(this.nodeName)
     val currRel =
-      getRelNode(context, getFinalChildOutput(), operatorId, childCtx.root, validation = false)
+      getRelNode(context, getFinalChildOutput, operatorId, childCtx.root, validation = false)
     assert(currRel != null, "Write Rel should be valid")
     TransformContext(childCtx.outputAttributes, output, currRel)
   }
@@ -196,7 +197,7 @@ object WriteFilesExecTransformer {
     "__file_source_generated_metadata_col"
   )
 
-  def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
+  private def removeMetadata(attr: Attribute, metadataExclusionList: Seq[String]): Attribute = {
     val metadataKeys = INTERNAL_METADATA_KEYS ++ metadataExclusionList
     attr.withMetadata {
       var builder = new MetadataBuilder().withMetadata(attr.metadata)
diff --git a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
index e8e7ce06feaf4..c641cb44891da 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/utils/SubstraitUtil.scala
@@ -16,10 +16,19 @@
  */
 package org.apache.gluten.utils
 
+import org.apache.gluten.backendsapi.BackendsApiManager
+import org.apache.gluten.expression.{ConverterUtils, ExpressionConverter}
+import org.apache.gluten.substrait.`type`.TypeBuilder
+import org.apache.gluten.substrait.SubstraitContext
+import org.apache.gluten.substrait.expression.ExpressionNode
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
 import org.apache.spark.sql.catalyst.plans.{FullOuter, InnerLike, JoinType, LeftAnti, LeftOuter, LeftSemi, RightOuter}
 
 import io.substrait.proto.{CrossRel, JoinRel}
 
+import scala.collection.JavaConverters._
+
 object SubstraitUtil {
   def toSubstrait(sparkJoin: JoinType): JoinRel.JoinType = sparkJoin match {
     case _: InnerLike =>
@@ -55,4 +64,24 @@ object SubstraitUtil {
     case _ =>
       CrossRel.JoinType.UNRECOGNIZED
   }
+
+  def createEnhancement(output: Seq[Attribute]): com.google.protobuf.Any = {
+    val inputTypeNodes = output.map {
+      attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)
+    }
+    // Normally the enhancement node is only used for plan validation. But here the enhancement
+    // is also used in execution phase. In this case an empty typeUrlPrefix need to be passed,
+    // so that it can be correctly parsed into json string on the cpp side.
+    BackendsApiManager.getTransformerApiInstance.packPBMessage(
+      TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf)
+  }
+
+  def toSubstraitExpression(
+      expr: Expression,
+      attributeSeq: Seq[Attribute],
+      context: SubstraitContext): ExpressionNode = {
+    ExpressionConverter
+      .replaceWithExpressionTransformer(expr, attributeSeq)
+      .doTransform(context.registeredFunction)
+  }
 }
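Reviewer note: a minimal sketch of the call-site pattern this refactor consolidates. The `buildPostJoinFilter` helper and its parameter names are hypothetical stand-ins for values each transformer already has in scope; only `SubstraitUtil.toSubstraitExpression` and `SubstraitUtil.createEnhancement` come from this patch.

```scala
import org.apache.gluten.substrait.SubstraitContext
import org.apache.gluten.substrait.expression.ExpressionNode
import org.apache.gluten.utils.SubstraitUtil

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

// Hypothetical call site. Before this patch, every join transformer repeated:
//   ExpressionConverter
//     .replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput)
//     .doTransform(context.registeredFunction)
// After it, the boilerplate lives in one shared helper:
def buildPostJoinFilter(
    condition: Option[Expression],
    streamedOutput: Seq[Attribute],
    buildOutput: Seq[Attribute],
    context: SubstraitContext): Option[ExpressionNode] = {
  condition.map {
    SubstraitUtil.toSubstraitExpression(_, streamedOutput ++ buildOutput, context)
  }
}
```

The `WriteFilesExecTransformer` change additionally threads the file format's `DataSourceRegister.shortName` into the native write parameters: assuming a Parquet write with Snappy compression, the parameters string should now read `WriteParameters:isSnappy=1;format=parquet` (plus a trailing newline) rather than `WriteParameters:isSnappy=1`.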