Skip to content

Commit

Permalink
Support row type and fix subfield
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo committed Jul 30, 2024
1 parent 4ae223c commit 484abf4
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DecimalType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, DecimalType, IntegerType, StringType, StructField, StructType}

import java.util.concurrent.TimeUnit

Expand Down Expand Up @@ -102,6 +102,33 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
"where l_comment is null") { _ => }
assert(df.isEmpty)
checkLengthAndPlan(df, 0)

// Struct of array.
val data =
Row(Row(Array("a", "b", "c"), null)) ::
Row(Row(Array("d", "e", "f"), Array(1, 2, 3))) ::
Row(Row(null, null)) :: Nil

val schema = new StructType()
.add(
"struct",
new StructType()
.add("a0", ArrayType(StringType))
.add("a1", ArrayType(IntegerType)))

val dataFrame = spark.createDataFrame(JavaConverters.seqAsJavaList(data), schema)

withTempPath {
path =>
dataFrame.write.parquet(path.getCanonicalPath)
spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view")
runQueryAndCompare("select * from view where struct is null") {
checkGlutenOperatorMatch[FileSourceScanExecTransformer]
}
runQueryAndCompare("select * from view where struct.a0 is null") {
checkGlutenOperatorMatch[FileSourceScanExecTransformer]
}
}
}

test("is_null_has_null") {
Expand All @@ -119,6 +146,33 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
"select l_orderkey from lineitem where l_comment is not null " +
"and l_orderkey = 1") { _ => }
checkLengthAndPlan(df, 6)

// Struct of array.
val data =
Row(Row(Array("a", "b", "c"), null)) ::
Row(Row(Array("d", "e", "f"), Array(1, 2, 3))) ::
Row(Row(null, null)) :: Nil

val schema = new StructType()
.add(
"struct",
new StructType()
.add("a0", ArrayType(StringType))
.add("a1", ArrayType(IntegerType)))

val dataFrame = spark.createDataFrame(JavaConverters.seqAsJavaList(data), schema)

withTempPath {
path =>
dataFrame.write.parquet(path.getCanonicalPath)
spark.read.parquet(path.getCanonicalPath).createOrReplaceTempView("view")
runQueryAndCompare("select * from view where struct is not null") {
checkGlutenOperatorMatch[FileSourceScanExecTransformer]
}
runQueryAndCompare("select * from view where struct.a0 is not null") {
checkGlutenOperatorMatch[FileSourceScanExecTransformer]
}
}
}

test("is_null and is_not_null coexist") {
Expand Down
14 changes: 12 additions & 2 deletions cpp/velox/substrait/SubstraitParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,21 @@ void SubstraitParser::parseColumnTypes(
return;
}

int32_t SubstraitParser::parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment) {
bool SubstraitParser::parseReferenceSegment(
const ::substrait::Expression::ReferenceSegment& refSegment,
uint32_t& fieldIndex) {
auto typeCase = refSegment.reference_type_case();
switch (typeCase) {
case ::substrait::Expression::ReferenceSegment::ReferenceTypeCase::kStructField: {
return refSegment.struct_field().field();
if (refSegment.struct_field().has_child()) {
// Parsing a subfield (nested child) index is not supported.
return false;
}
fieldIndex = refSegment.struct_field().field();
if (fieldIndex < 0) {
return false;
}
return true;
}
default:
VELOX_NYI("Substrait conversion not supported for ReferenceSegment '{}'", std::to_string(typeCase));
Expand Down
5 changes: 3 additions & 2 deletions cpp/velox/substrait/SubstraitParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ class SubstraitParser {
/// Parse Substrait Type to Velox type.
static facebook::velox::TypePtr parseType(const ::substrait::Type& substraitType, bool asLowerCase = false);

/// Parse Substrait ReferenceSegment.
static int32_t parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment);
/// Parse a Substrait ReferenceSegment and extract the struct field index into 'fieldIndex'. Returns false if the
/// segment refers to a nested (child) field, which is not supported.
static bool parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment, uint32_t& fieldIndex);

/// Make names in the format of {prefix}_{index}.
static std::vector<std::string> makeNames(const std::string& prefix, int size);
Expand Down
62 changes: 40 additions & 22 deletions cpp/velox/substrait/SubstraitToVeloxPlan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1530,7 +1530,9 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral(
if (arguments.size() == 1) {
if (arguments[0].value().has_selection()) {
// Only field exists.
fieldIndex = SubstraitParser::parseReferenceSegment(arguments[0].value().selection().direct_reference());
if (!SubstraitParser::parseReferenceSegment(arguments[0].value().selection().direct_reference(), fieldIndex)) {
return false;
}
return true;
} else {
return false;
Expand All @@ -1547,7 +1549,9 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral(
auto typeCase = param.value().rex_type_case();
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kSelection:
fieldIndex = SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference());
if (!SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference(), fieldIndex)) {
return false;
}
fieldExists = true;
break;
case ::substrait::Expression::RexTypeCase::kLiteral:
Expand All @@ -1564,22 +1568,24 @@ bool SubstraitToVeloxPlanConverter::fieldOrWithLiteral(
bool SubstraitToVeloxPlanConverter::childrenFunctionsOnSameField(
const ::substrait::Expression_ScalarFunction& function) {
// Get the column indices of the children functions.
std::vector<int32_t> colIndices;
std::vector<uint32_t> colIndices;
for (const auto& arg : function.arguments()) {
if (arg.value().has_scalar_function()) {
const auto& scalarFunction = arg.value().scalar_function();
for (const auto& param : scalarFunction.arguments()) {
if (param.value().has_selection()) {
const auto& field = param.value().selection();
VELOX_CHECK(field.has_direct_reference());
int32_t colIdx = SubstraitParser::parseReferenceSegment(field.direct_reference());
uint32_t colIdx;
if (!SubstraitParser::parseReferenceSegment(field.direct_reference(), colIdx)) {
return false;
}
colIndices.emplace_back(colIdx);
}
}
} else if (arg.value().has_singular_or_list()) {
const auto& singularOrList = arg.value().singular_or_list();
int32_t colIdx = getColumnIndexFromSingularOrList(singularOrList);
colIndices.emplace_back(colIdx);
colIndices.emplace_back(getColumnIndexFromSingularOrList(singularOrList));
} else {
return false;
}
Expand Down Expand Up @@ -1711,8 +1717,9 @@ void SubstraitToVeloxPlanConverter::separateFilters(
if (format == dwio::common::FileFormat::ORC && scalarFunction.arguments().size() > 0) {
auto value = scalarFunction.arguments().at(0).value();
if (value.has_selection()) {
uint32_t fieldIndex = SubstraitParser::parseReferenceSegment(value.selection().direct_reference());
if (!veloxTypeList.empty() && veloxTypeList.at(fieldIndex)->isDecimal()) {
uint32_t fieldIndex;
bool parsed = SubstraitParser::parseReferenceSegment(value.selection().direct_reference(), fieldIndex);
if (!parsed || (!veloxTypeList.empty() && veloxTypeList.at(fieldIndex)->isDecimal())) {
remainingFunctions.emplace_back(scalarFunction);
continue;
}
Expand Down Expand Up @@ -1872,7 +1879,10 @@ void SubstraitToVeloxPlanConverter::setFilterInfo(
switch (typeCase) {
case ::substrait::Expression::RexTypeCase::kSelection:
typeCases.emplace_back("kSelection");
colIdx = SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference());
uint32_t index;
bool parsed = SubstraitParser::parseReferenceSegment(param.value().selection().direct_reference(), index);
VELOX_CHECK(parsed, "Failed to parse the column index from the selection.");
colIdx = index;
break;
case ::substrait::Expression::RexTypeCase::kLiteral:
typeCases.emplace_back("kLiteral");
Expand Down Expand Up @@ -2177,18 +2187,17 @@ void SubstraitToVeloxPlanConverter::constructSubfieldFilters(
VELOX_CHECK(value == filterInfo.upperBounds_[0].value().value<bool>(), "invalid state of bool equal");
filters[common::Subfield(inputName)] = std::make_unique<common::BoolValue>(value, nullAllowed);
}
} else if constexpr (KIND == facebook::velox::TypeKind::ARRAY || KIND == facebook::velox::TypeKind::MAP) {
// Only IsNotNull and IsNull are supported for array and map types.
if (rangeSize == 0) {
if (!nullAllowed) {
filters[common::Subfield(inputName)] = std::make_unique<common::IsNotNull>();
} else if (isNull) {
filters[common::Subfield(inputName)] = std::make_unique<common::IsNull>();
} else {
VELOX_NYI(
"Only IsNotNull and IsNull are supported in constructSubfieldFilters for input type '{}'.",
inputType->toString());
}
} else if constexpr (
KIND == facebook::velox::TypeKind::ARRAY || KIND == facebook::velox::TypeKind::MAP ||
KIND == facebook::velox::TypeKind::ROW) {
// Only IsNotNull and IsNull are supported for complex types.
VELOX_CHECK_EQ(rangeSize, 0, "Only IsNotNull and IsNull are supported for complex type.");
if (!nullAllowed) {
filters[common::Subfield(inputName)] = std::make_unique<common::IsNotNull>();
} else if (isNull) {
filters[common::Subfield(inputName)] = std::make_unique<common::IsNull>();
} else {
VELOX_NYI("Only IsNotNull and IsNull are supported for input type '{}'.", inputType->toString());
}
} else {
using NativeType = typename RangeTraits<KIND>::NativeType;
Expand Down Expand Up @@ -2393,6 +2402,10 @@ connector::hive::SubfieldFilters SubstraitToVeloxPlanConverter::mapToFilters(
constructSubfieldFilters<TypeKind::MAP, common::Filter>(
colIdx, inputNameList[colIdx], inputType, columnToFilterInfo[colIdx], filters);
break;
case TypeKind::ROW:
constructSubfieldFilters<TypeKind::ROW, common::Filter>(
colIdx, inputNameList[colIdx], inputType, columnToFilterInfo[colIdx], filters);
break;
default:
VELOX_NYI(
"Subfield filters creation not supported for input type '{}' in mapToFilters", inputType->toString());
Expand Down Expand Up @@ -2494,7 +2507,12 @@ uint32_t SubstraitToVeloxPlanConverter::getColumnIndexFromSingularOrList(
} else {
VELOX_FAIL("Unsupported type in IN pushdown.");
}
return SubstraitParser::parseReferenceSegment(selection.direct_reference());
uint32_t index;
VELOX_CHECK(
SubstraitParser::parseReferenceSegment(selection.direct_reference(), index),
"Failed to parse column index from SingularOrList."
);
return index;
}

void SubstraitToVeloxPlanConverter::setFilterInfo(
Expand Down

0 comments on commit 484abf4

Please sign in to comment.