From d5b685ac7f95519c589aafbb0ebcfc61fc0f4c42 Mon Sep 17 00:00:00 2001 From: rui-mo Date: Mon, 29 Jul 2024 09:26:27 +0800 Subject: [PATCH] Support row type and fix subfield --- .../gluten/execution/TestOperator.scala | 56 +++++++++++++- cpp/velox/substrait/SubstraitParser.cc | 14 +++- cpp/velox/substrait/SubstraitParser.h | 5 +- cpp/velox/substrait/SubstraitToVeloxPlan.cc | 73 ++++++++++++------- 4 files changed, 116 insertions(+), 32 deletions(-) diff --git a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala index dcae4920d01c..0fb5fb54900b 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/execution/TestOperator.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DecimalType, IntegerType, StringType, StructField, StructType} import java.util.concurrent.TimeUnit @@ -102,6 +102,33 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla "where l_comment is null") { _ => } assert(df.isEmpty) checkLengthAndPlan(df, 0) + + // Struct of array. + val data = + Row(Row(Array("a", "b", "c"), null)) :: + Row(Row(Array("d", "e", "f"), Array(1, 2, 3))) :: + Row(Row(null, null)) :: Nil + + val schema = new StructType() + .add( + "struct", + new StructType() + .add("a0", ArrayType(StringType)) + .add("a1", ArrayType(IntegerType))) + + val dataFrame = spark.createDataFrame(JavaConverters.seqAsJavaList(data), schema) + + withTempPath { + path => + dataFrame.write.parquet(path.getCanonicalPath) + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view") + runQueryAndCompare("select * from view where struct is null") { + checkGlutenOperatorMatch[FileSourceScanExecTransformer] + } + runQueryAndCompare("select * from view where struct.a0 is null") { + checkGlutenOperatorMatch[FileSourceScanExecTransformer] + } + } } test("is_null_has_null") { @@ -119,6 +146,33 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla "select l_orderkey from lineitem where l_comment is not null " + "and l_orderkey = 1") { _ => } checkLengthAndPlan(df, 6) + + // Struct of array. + val data = + Row(Row(Array("a", "b", "c"), null)) :: + Row(Row(Array("d", "e", "f"), Array(1, 2, 3))) :: + Row(Row(null, null)) :: Nil + + val schema = new StructType() + .add( + "struct", + new StructType() + .add("a0", ArrayType(StringType)) + .add("a1", ArrayType(IntegerType))) + + val dataFrame = spark.createDataFrame(JavaConverters.seqAsJavaList(data), schema) + + withTempPath { + path => + dataFrame.write.parquet(path.getCanonicalPath) + spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view") + runQueryAndCompare("select * from view where struct is not null") { + checkGlutenOperatorMatch[FileSourceScanExecTransformer] + } + runQueryAndCompare("select * from view where struct.a0 is not null") { + checkGlutenOperatorMatch[FileSourceScanExecTransformer] + } + } } test("is_null and is_not_null coexist") { diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index b842914ca933..793f68184daa 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -141,11 +141,21 @@ void SubstraitParser::parseColumnTypes( return; } -int32_t SubstraitParser::parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment) { +bool SubstraitParser::parseReferenceSegment( + const ::substrait::Expression::ReferenceSegment& refSegment, + uint32_t& fieldIndex) { auto typeCase = refSegment.reference_type_case(); switch (typeCase) { case ::substrait::Expression::ReferenceSegment::ReferenceTypeCase::kStructField: { - return refSegment.struct_field().field(); + if (refSegment.struct_field().has_child()) { + // To parse subfield index is not supported. + return false; + } + fieldIndex = refSegment.struct_field().field(); + if (fieldIndex < 0) { + return false; + } + return true; } default: VELOX_NYI("Substrait conversion not supported for ReferenceSegment '{}'", std::to_string(typeCase)); diff --git a/cpp/velox/substrait/SubstraitParser.h b/cpp/velox/substrait/SubstraitParser.h index 1f766b91ca1b..f42d05b4a21c 100644 --- a/cpp/velox/substrait/SubstraitParser.h +++ b/cpp/velox/substrait/SubstraitParser.h @@ -50,8 +50,9 @@ class SubstraitParser { /// Parse Substrait Type to Velox type. static facebook::velox::TypePtr parseType(const ::substrait::Type& substraitType, bool asLowerCase = false); - /// Parse Substrait ReferenceSegment. - static int32_t parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment); + /// Parse Substrait ReferenceSegment and extract the field index. Return false if the segment is not a valid unnested + /// field. + static bool parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment, uint32_t& fieldIndex); /// Make names in the format of {prefix}_{index}. static std::vector makeNames(const std::string& prefix, int size); diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index 7b41f7071e84..d7de841191ed 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -1530,8 +1530,7 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral( if (arguments.size() == 1) { if (arguments[0].value().has_selection()) { // Only field exists. - fieldIndex = SubstraitParser::parseReferenceSegment(arguments[0].value().selection().direct_reference()); - return true; + return SubstraitParser::parseReferenceSegment(arguments[0].value().selection().direct_reference(), fieldIndex); } else { return false; } @@ -1546,13 +1545,17 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral( for (const auto& param : arguments) { auto typeCase = param.value().rex_type_case(); switch (typeCase) { - case ::substrait::Expression::RexTypeCase::kSelection: - fieldIndex = SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference()); + case ::substrait::Expression::RexTypeCase::kSelection: { + if (!SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference(), fieldIndex)) { + return false; + } fieldExists = true; break; - case ::substrait::Expression::RexTypeCase::kLiteral: + } + case ::substrait::Expression::RexTypeCase::kLiteral: { literalExists = true; break; + } default: break; } @@ -1564,7 +1567,7 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral( bool SubstraitToVeloxPlanConverter::childrenFunctionsOnSameField( const ::substrait::Expression_ScalarFunction& function) { // Get the column indices of the children functions. - std::vector colIndices; + std::vector colIndices; for (const auto& arg : function.arguments()) { if (arg.value().has_scalar_function()) { const auto& scalarFunction = arg.value().scalar_function(); @@ -1572,14 +1575,16 @@ bool SubstraitToVeloxPlanConverter::childrenFunctionsOnSameField( if (param.value().has_selection()) { const auto& field = param.value().selection(); VELOX_CHECK(field.has_direct_reference()); - int32_t colIdx = SubstraitParser::parseReferenceSegment(field.direct_reference()); + uint32_t colIdx; + if (!SubstraitParser::parseReferenceSegment(field.direct_reference(), colIdx)) { + return false; + } colIndices.emplace_back(colIdx); } } } else if (arg.value().has_singular_or_list()) { const auto& singularOrList = arg.value().singular_or_list(); - int32_t colIdx = getColumnIndexFromSingularOrList(singularOrList); - colIndices.emplace_back(colIdx); + colIndices.emplace_back(getColumnIndexFromSingularOrList(singularOrList)); } else { return false; } @@ -1711,8 +1716,9 @@ void SubstraitToVeloxPlanConverter::separateFilters( if (format == dwio::common::FileFormat::ORC && scalarFunction.arguments().size() > 0) { auto value = scalarFunction.arguments().at(0).value(); if (value.has_selection()) { - uint32_t fieldIndex = SubstraitParser::parseReferenceSegment(value.selection().direct_reference()); - if (!veloxTypeList.empty() && veloxTypeList.at(fieldIndex)->isDecimal()) { + uint32_t fieldIndex; + bool parsed = SubstraitParser::parseReferenceSegment(value.selection().direct_reference(), fieldIndex); + if (!parsed || (!veloxTypeList.empty() && veloxTypeList.at(fieldIndex)->isDecimal())) { remainingFunctions.emplace_back(scalarFunction); continue; } @@ -1870,14 +1876,20 @@ void SubstraitToVeloxPlanConverter::setFilterInfo( for (const auto& param : scalarFunction.arguments()) { auto typeCase = param.value().rex_type_case(); switch (typeCase) { - case ::substrait::Expression::RexTypeCase::kSelection: + case ::substrait::Expression::RexTypeCase::kSelection: { typeCases.emplace_back("kSelection"); - colIdx = SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference()); + uint32_t index; + VELOX_CHECK( + SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference(), index), + "Failed to parse the column index from the selection."); + colIdx = index; break; - case ::substrait::Expression::RexTypeCase::kLiteral: + } + case ::substrait::Expression::RexTypeCase::kLiteral: { typeCases.emplace_back("kLiteral"); substraitLit = param.value().literal(); break; + } default: VELOX_NYI("Substrait conversion not supported for arg type '{}'", std::to_string(typeCase)); } @@ -2177,18 +2189,17 @@ void SubstraitToVeloxPlanConverter::constructSubfieldFilters( VELOX_CHECK(value == filterInfo.upperBounds_[0].value().value(), "invalid state of bool equal"); filters[common::Subfield(inputName)] = std::make_unique(value, nullAllowed); } - } else if constexpr (KIND == facebook::velox::TypeKind::ARRAY || KIND == facebook::velox::TypeKind::MAP) { - // Only IsNotNull and IsNull are supported for array and map types. - if (rangeSize == 0) { - if (!nullAllowed) { - filters[common::Subfield(inputName)] = std::make_unique(); - } else if (isNull) { - filters[common::Subfield(inputName)] = std::make_unique(); - } else { - VELOX_NYI( - "Only IsNotNull and IsNull are supported in constructSubfieldFilters for input type '{}'.", - inputType->toString()); - } + } else if constexpr ( + KIND == facebook::velox::TypeKind::ARRAY || KIND == facebook::velox::TypeKind::MAP || + KIND == facebook::velox::TypeKind::ROW) { + // Only IsNotNull and IsNull are supported for complex types. + VELOX_CHECK_EQ(rangeSize, 0, "Only IsNotNull and IsNull are supported for complex type."); + if (!nullAllowed) { + filters[common::Subfield(inputName)] = std::make_unique(); + } else if (isNull) { + filters[common::Subfield(inputName)] = std::make_unique(); + } else { + VELOX_NYI("Only IsNotNull and IsNull are supported for input type '{}'.", inputType->toString()); } } else { using NativeType = typename RangeTraits::NativeType; @@ -2393,6 +2404,10 @@ connector::hive::SubfieldFilters SubstraitToVeloxPlanConverter::mapToFilters( constructSubfieldFilters( colIdx, inputNameList[colIdx], inputType, columnToFilterInfo[colIdx], filters); break; + case TypeKind::ROW: + constructSubfieldFilters( + colIdx, inputNameList[colIdx], inputType, columnToFilterInfo[colIdx], filters); + break; default: VELOX_NYI( "Subfield filters creation not supported for input type '{}' in mapToFilters", inputType->toString()); @@ -2494,7 +2509,11 @@ uint32_t SubstraitToVeloxPlanConverter::getColumnIndexFromSingularOrList( } else { VELOX_FAIL("Unsupported type in IN pushdown."); } - return SubstraitParser::parseReferenceSegment(selection.direct_reference()); + uint32_t index; + VELOX_CHECK( + SubstraitParser::parseReferenceSegment(selection.direct_reference(), index), + "Failed to parse column index from SingularOrList."); + return index; } void SubstraitToVeloxPlanConverter::setFilterInfo(