diff --git a/cpp/velox/CMakeLists.txt b/cpp/velox/CMakeLists.txt index fdc2ec50c9e8..05c2811077bf 100644 --- a/cpp/velox/CMakeLists.txt +++ b/cpp/velox/CMakeLists.txt @@ -252,7 +252,6 @@ set(VELOX_SRCS substrait/SubstraitToVeloxPlan.cc substrait/SubstraitToVeloxPlanValidator.cc substrait/VariantToVectorConverter.cc - substrait/TypeUtils.cc substrait/SubstraitExtensionCollector.cc substrait/VeloxSubstraitSignature.cc substrait/VeloxToSubstraitExpr.cc diff --git a/cpp/velox/compute/VeloxPlanConverter.cc b/cpp/velox/compute/VeloxPlanConverter.cc index f04ca6b6c078..3766b5af2299 100644 --- a/cpp/velox/compute/VeloxPlanConverter.cc +++ b/cpp/velox/compute/VeloxPlanConverter.cc @@ -124,13 +124,13 @@ void VeloxPlanConverter::setInputPlanNode(const ::substrait::ReadRel& sread) { // Get the input schema of this iterator. uint64_t colNum = 0; - std::vector> subTypeList; + std::vector veloxTypeList; if (sread.has_base_schema()) { const auto& baseSchema = sread.base_schema(); // Input names is not used. Instead, new input/output names will be created // because the ValueStreamNode in Velox does not support name change. colNum = baseSchema.names().size(); - subTypeList = SubstraitParser::parseNamedStruct(baseSchema); + veloxTypeList = SubstraitParser::parseNamedStruct(baseSchema); } std::vector outNames; @@ -140,10 +140,6 @@ void VeloxPlanConverter::setInputPlanNode(const ::substrait::ReadRel& sread) { outNames.emplace_back(colName); } - std::vector veloxTypeList; - for (auto subType : subTypeList) { - veloxTypeList.push_back(toVeloxType(subType->type)); - } auto outputType = ROW(std::move(outNames), std::move(veloxTypeList)); auto vectorStream = std::make_shared(pool_, std::move(inputIters_[iterIdx]), outputType); auto valuesNode = std::make_shared(nextPlanNodeId(), outputType, std::move(vectorStream)); diff --git a/cpp/velox/substrait/SubstraitParser.cc b/cpp/velox/substrait/SubstraitParser.cc index 41b2eab7e91b..2de7012f3bdc 100644 --- a/cpp/velox/substrait/SubstraitParser.cc +++ b/cpp/velox/substrait/SubstraitParser.cc @@ -19,168 +19,87 @@ #include "TypeUtils.h" #include "velox/common/base/Exceptions.h" +#include "VeloxSubstraitSignature.h" + namespace gluten { -SubstraitParser::SubstraitType SubstraitParser::parseType(const ::substrait::Type& substraitType) { - // The used type names should be aligned with those in Velox. - std::string typeName; - ::substrait::Type_Nullability nullability; +TypePtr SubstraitParser::parseType(const ::substrait::Type& substraitType, bool asLowerCase) { switch (substraitType.kind_case()) { - case ::substrait::Type::KindCase::kBool: { - typeName = "BOOLEAN"; - nullability = substraitType.bool_().nullability(); - break; - } - case ::substrait::Type::KindCase::kI8: { - typeName = "TINYINT"; - nullability = substraitType.i8().nullability(); - break; - } - case ::substrait::Type::KindCase::kI16: { - typeName = "SMALLINT"; - nullability = substraitType.i16().nullability(); - break; - } - case ::substrait::Type::KindCase::kI32: { - typeName = "INTEGER"; - nullability = substraitType.i32().nullability(); - break; - } - case ::substrait::Type::KindCase::kI64: { - typeName = "BIGINT"; - nullability = substraitType.i64().nullability(); - break; - } - case ::substrait::Type::KindCase::kFp32: { - typeName = "REAL"; - nullability = substraitType.fp32().nullability(); - break; - } - case ::substrait::Type::KindCase::kFp64: { - typeName = "DOUBLE"; - nullability = substraitType.fp64().nullability(); - break; - } - case ::substrait::Type::KindCase::kString: { - typeName = "VARCHAR"; - nullability = substraitType.string().nullability(); - break; - } - case ::substrait::Type::KindCase::kBinary: { - typeName = "VARBINARY"; - nullability = substraitType.string().nullability(); - break; - } + case ::substrait::Type::KindCase::kBool: + return BOOLEAN(); + case ::substrait::Type::KindCase::kI8: + return TINYINT(); + case ::substrait::Type::KindCase::kI16: + return SMALLINT(); + case ::substrait::Type::KindCase::kI32: + return INTEGER(); + case ::substrait::Type::KindCase::kI64: + return BIGINT(); + case ::substrait::Type::KindCase::kFp32: + return REAL(); + case ::substrait::Type::KindCase::kFp64: + return DOUBLE(); + case ::substrait::Type::KindCase::kString: + return VARCHAR(); + case ::substrait::Type::KindCase::kBinary: + return VARBINARY(); case ::substrait::Type::KindCase::kStruct: { - // The type name of struct is in the format of: - // ROW,...typen:namen>. - typeName = "ROW<"; const auto& substraitStruct = substraitType.struct_(); const auto& structTypes = substraitStruct.types(); const auto& structNames = substraitStruct.names(); bool nameProvided = structTypes.size() == structNames.size(); + std::vector types; + std::vector names; for (int i = 0; i < structTypes.size(); i++) { - if (i > 0) { - typeName += ','; - } - typeName += parseType(structTypes[i]).type; - // Struct names could be empty. - if (nameProvided) { - typeName += (':' + structNames[i]); + types.emplace_back(parseType(structTypes[i])); + std::string fieldName = nameProvided ? structNames[i] : "col_" + std::to_string(i); + if (asLowerCase) { + folly::toLowerAscii(fieldName); } + names.emplace_back(fieldName); } - typeName += '>'; - nullability = substraitType.struct_().nullability(); - break; + return ROW(std::move(names), std::move(types)); } case ::substrait::Type::KindCase::kList: { - // The type name of list is in the format of: ARRAY. - const auto& sList = substraitType.list(); - const auto& sType = sList.type(); - typeName = "ARRAY<" + parseType(sType).type + ">"; - nullability = substraitType.list().nullability(); - break; + const auto& fieldType = substraitType.list().type(); + return ARRAY(parseType(fieldType)); } case ::substrait::Type::KindCase::kMap: { - // The type name of map is in the format of: MAP. const auto& sMap = substraitType.map(); const auto& keyType = sMap.key(); const auto& valueType = sMap.value(); - typeName = "MAP<" + parseType(keyType).type + "," + parseType(valueType).type + ">"; - nullability = substraitType.map().nullability(); - break; + return MAP(parseType(keyType), parseType(valueType)); } - case ::substrait::Type::KindCase::kUserDefined: { + case ::substrait::Type::KindCase::kUserDefined: // We only support UNKNOWN type to handle the null literal whose type is // not known. - VELOX_CHECK_EQ(substraitType.user_defined().type_reference(), 0); - typeName = "UNKNOWN"; - nullability = substraitType.string().nullability(); - break; - } - case ::substrait::Type::KindCase::kDate: { - typeName = "DATE"; - nullability = substraitType.date().nullability(); - break; - } - case ::substrait::Type::KindCase::kTimestamp: { - typeName = "TIMESTAMP"; - nullability = substraitType.timestamp().nullability(); - break; - } + return UNKNOWN(); + case ::substrait::Type::KindCase::kDate: + return DATE(); + case ::substrait::Type::KindCase::kTimestamp: + return TIMESTAMP(); case ::substrait::Type::KindCase::kDecimal: { auto precision = substraitType.decimal().precision(); auto scale = substraitType.decimal().scale(); - if (precision <= 18) { - typeName = "SHORT_DECIMAL<" + std::to_string(precision) + "," + std::to_string(scale) + ">"; - } else { - typeName = "HUGEINT<" + std::to_string(precision) + "," + std::to_string(scale) + ">"; - } - nullability = substraitType.decimal().nullability(); - break; + return DECIMAL(precision, scale); } default: VELOX_NYI("Parsing for Substrait type not supported: {}", substraitType.DebugString()); } - - bool nullable; - switch (nullability) { - case ::substrait::Type_Nullability::Type_Nullability_NULLABILITY_UNSPECIFIED: - nullable = true; - break; - case ::substrait::Type_Nullability::Type_Nullability_NULLABILITY_NULLABLE: - nullable = true; - break; - case ::substrait::Type_Nullability::Type_Nullability_NULLABILITY_REQUIRED: - nullable = false; - break; - default: - VELOX_NYI("Substrait parsing for nullability {} not supported.", nullability); - } - return SubstraitType{typeName, nullable}; } -std::string SubstraitParser::parseType(const std::string& substraitType) { - auto it = typeMap_.find(substraitType); - if (it == typeMap_.end()) { - VELOX_NYI("Substrait parsing for type {} not supported.", substraitType); - } - return it->second; -}; - -std::vector> SubstraitParser::parseNamedStruct( - const ::substrait::NamedStruct& namedStruct) { +std::vector SubstraitParser::parseNamedStruct(const ::substrait::NamedStruct& namedStruct, bool asLowerCase) { // Note that "names" are not used. // Parse Struct. const auto& substraitStruct = namedStruct.struct_(); const auto& substraitTypes = substraitStruct.types(); - std::vector> substraitTypeList; - substraitTypeList.reserve(substraitTypes.size()); + std::vector typeList; + typeList.reserve(substraitTypes.size()); for (const auto& type : substraitTypes) { - substraitTypeList.emplace_back(std::make_shared(parseType(type))); + typeList.emplace_back(parseType(type, asLowerCase)); } - return substraitTypeList; + return typeList; } std::vector SubstraitParser::parsePartitionColumns(const ::substrait::NamedStruct& namedStruct) { @@ -262,55 +181,57 @@ std::string SubstraitParser::findFunctionSpec( return x->second; } -std::string SubstraitParser::getSubFunctionName(const std::string& subFuncSpec) { - // Get the position of ":" in the function name. - std::size_t pos = subFuncSpec.find(":"); +// TODO Refactor using Bison. +std::string SubstraitParser::getNameBeforeDelimiter(const std::string& signature, const std::string& delimiter) { + std::size_t pos = signature.find(delimiter); if (pos == std::string::npos) { - return subFuncSpec; + return signature; } - return subFuncSpec.substr(0, pos); + return signature.substr(0, pos); } -void SubstraitParser::getSubFunctionTypes(const std::string& subFuncSpec, std::vector& types) { +std::vector SubstraitParser::getSubFunctionTypes(const std::string& substraitFunction) { // Get the position of ":" in the function name. - std::size_t pos = subFuncSpec.find(":"); + size_t pos = substraitFunction.find(":"); // Get the parameter types. - std::string funcTypes; - if (pos == std::string::npos) { - funcTypes = subFuncSpec; - } else { - if (pos == subFuncSpec.size() - 1) { - return; - } - funcTypes = subFuncSpec.substr(pos + 1); + std::vector types; + if (pos == std::string::npos || pos == substraitFunction.size() - 1) { + return types; } - // Split the types with delimiter. - std::string delimiter = "_"; - while ((pos = funcTypes.find(delimiter)) != std::string::npos) { - auto type = funcTypes.substr(0, pos); - if (type != "opt" && type != "req") { - types.emplace_back(type); + // Extract input types with delimiter. + for (;;) { + const size_t endPos = substraitFunction.find("_", pos + 1); + if (endPos == std::string::npos) { + std::string typeName = substraitFunction.substr(pos + 1); + if (typeName != "opt" && typeName != "req") { + types.emplace_back(typeName); + } + break; } - funcTypes.erase(0, pos + delimiter.length()); + + const std::string typeName = substraitFunction.substr(pos + 1, endPos - pos - 1); + if (typeName != "opt" && typeName != "req") { + types.emplace_back(typeName); + } + pos = endPos; } - types.emplace_back(funcTypes); + return types; } std::string SubstraitParser::findVeloxFunction( const std::unordered_map& functionMap, uint64_t id) { std::string funcSpec = findFunctionSpec(functionMap, id); - std::string_view funcName = getNameBeforeDelimiter(funcSpec, ":"); - std::vector types; - getSubFunctionTypes(funcSpec, types); + std::string funcName = getNameBeforeDelimiter(funcSpec); + std::vector types = getSubFunctionTypes(funcSpec); bool isDecimal = false; - for (auto& type : types) { + for (const auto& type : types) { if (type.find("dec") != std::string::npos) { isDecimal = true; break; } } - return mapToVeloxFunction({funcName.begin(), funcName.end()}, isDecimal); + return mapToVeloxFunction(funcName, isDecimal); } std::string SubstraitParser::mapToVeloxFunction(const std::string& substraitFunction, bool isDecimal) { @@ -347,6 +268,16 @@ bool SubstraitParser::configSetInOptimization( return false; } +std::vector SubstraitParser::sigToTypes(const std::string& signature) { + std::vector typeStrs = SubstraitParser::getSubFunctionTypes(signature); + std::vector types; + types.reserve(typeStrs.size()); + for (const auto& typeStr : typeStrs) { + types.emplace_back(VeloxSubstraitSignature::fromSubstraitSignature(typeStr)); + } + return types; +} + std::unordered_map SubstraitParser::substraitVeloxFunctionMap_ = { {"is_not_null", "isnotnull"}, /*Spark functions.*/ {"is_null", "isnull"}, diff --git a/cpp/velox/substrait/SubstraitParser.h b/cpp/velox/substrait/SubstraitParser.h index 8bb338ba7bff..fb39d65933af 100644 --- a/cpp/velox/substrait/SubstraitParser.h +++ b/cpp/velox/substrait/SubstraitParser.h @@ -28,29 +28,24 @@ #include +#include "velox/type/Type.h" + namespace gluten { /// This class contains some common functions used to parse Substrait /// components, and convert them into recognizable representations. class SubstraitParser { public: - /// Stores the type name and nullability. - struct SubstraitType { - std::string type; - bool nullable; - }; - /// Used to parse Substrait NamedStruct. - static std::vector> parseNamedStruct(const ::substrait::NamedStruct& namedStruct); + static std::vector parseNamedStruct( + const ::substrait::NamedStruct& namedStruct, + bool asLowerCase = false); /// Used to parse partition columns from Substrait NamedStruct. static std::vector parsePartitionColumns(const ::substrait::NamedStruct& namedStruct); - /// Parse Substrait Type. - static SubstraitType parseType(const ::substrait::Type& substraitType); - - // Parse substraitType type such as i32. - static std::string parseType(const std::string& substraitType); + /// Parse Substrait Type to Velox type. + static facebook::velox::TypePtr parseType(const ::substrait::Type& substraitType, bool asLowerCase = false); /// Parse Substrait ReferenceSegment. static int32_t parseReferenceSegment(const ::substrait::Expression::ReferenceSegment& refSegment); @@ -74,12 +69,11 @@ class SubstraitParser { /// specifications in Substrait yaml files. static std::string findFunctionSpec(const std::unordered_map& functionMap, uint64_t id); - /// Extracts the function name for a function from specified compound name. - /// When the input is a simple name, it will be returned. - static std::string getSubFunctionName(const std::string& functionSpec); + /// Extracts the name of a function by splitting signature with delimiter. + static std::string getNameBeforeDelimiter(const std::string& signature, const std::string& delimiter = ":"); /// This function is used get the types from the compound name. - static void getSubFunctionTypes(const std::string& subFuncSpec, std::vector& types); + static std::vector getSubFunctionTypes(const std::string& subFuncSpec); /// Used to find the Velox function name according to the function id /// from a pre-constructed function map. @@ -95,6 +89,9 @@ class SubstraitParser { /// @return Whether the config is set as true. static bool configSetInOptimization(const ::substrait::extensions::AdvancedExtension&, const std::string& config); + /// Extract input types from Substrait function signature. + static std::vector sigToTypes(const std::string& functionSig); + private: /// A map used for mapping Substrait function keywords into Velox functions' /// keywords. Key: the Substrait function keyword, Value: the Velox function diff --git a/cpp/velox/substrait/SubstraitToVeloxExpr.cc b/cpp/velox/substrait/SubstraitToVeloxExpr.cc index 15fca457cb47..e2ad05c81032 100644 --- a/cpp/velox/substrait/SubstraitToVeloxExpr.cc +++ b/cpp/velox/substrait/SubstraitToVeloxExpr.cc @@ -278,7 +278,7 @@ core::TypedExprPtr SubstraitVeloxExprConverter::toLambdaExpr( SubstraitParser::findVeloxFunction(functionMap_, arg.scalar_function().function_reference()); CHECK_EQ(veloxFunction, "namedlambdavariable"); argumentNames.emplace_back(arg.scalar_function().arguments(0).value().literal().string()); - argumentTypes.emplace_back(substraitTypeToVeloxType(substraitFunc.output_type())); + argumentTypes.emplace_back(SubstraitParser::parseType(substraitFunc.output_type())); } auto rowType = ROW(std::move(argumentNames), std::move(argumentTypes)); // Arg[0] -> function. @@ -296,16 +296,16 @@ core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( params.emplace_back(toVeloxExpr(sArg.value(), inputType)); } const auto& veloxFunction = SubstraitParser::findVeloxFunction(functionMap_, substraitFunc.function_reference()); - std::string typeName = SubstraitParser::parseType(substraitFunc.output_type()).type; + const auto& outputType = SubstraitParser::parseType(substraitFunc.output_type()); if (veloxFunction == "lambdafunction") { return toLambdaExpr(substraitFunc, inputType); } else if (veloxFunction == "namedlambdavariable") { - return makeFieldAccessExpr(substraitFunc.arguments(0).value().literal().string(), toVeloxType(typeName), nullptr); + return makeFieldAccessExpr(substraitFunc.arguments(0).value().literal().string(), outputType, nullptr); } else if (veloxFunction == "extract") { - return toExtractExpr(std::move(params), toVeloxType(typeName)); + return toExtractExpr(std::move(params), outputType); } else { - return std::make_shared(toVeloxType(typeName), std::move(params), veloxFunction); + return std::make_shared(outputType, std::move(params), veloxFunction); } } @@ -410,7 +410,7 @@ std::shared_ptr SubstraitVeloxExprConverter::toVe } } case ::substrait::Expression_Literal::LiteralTypeCase::kNull: { - auto veloxType = substraitTypeToVeloxType(substraitLit.null()); + auto veloxType = SubstraitParser::parseType(substraitLit.null()); if (veloxType->isShortDecimal()) { return std::make_shared(veloxType, variant::null(TypeKind::BIGINT)); } else if (veloxType->isLongDecimal()) { @@ -473,7 +473,7 @@ VectorPtr SubstraitVeloxExprConverter::literalsToVector( case ::substrait::Expression_Literal::LiteralTypeCase::kVarChar: return constructFlatVector(elementAtFunc, childSize, VARCHAR(), pool_); case ::substrait::Expression_Literal::LiteralTypeCase::kNull: { - auto veloxType = substraitTypeToVeloxType(literal.null()); + auto veloxType = SubstraitParser::parseType(literal.null()); auto kind = veloxType->kind(); return VELOX_DYNAMIC_SCALAR_TYPE_DISPATCH(constructFlatVector, kind, elementAtFunc, childSize, veloxType, pool_); } @@ -537,7 +537,7 @@ RowVectorPtr SubstraitVeloxExprConverter::literalsToRowVector(const ::substrait: core::TypedExprPtr SubstraitVeloxExprConverter::toVeloxExpr( const ::substrait::Expression::Cast& castExpr, const RowTypePtr& inputType) { - auto type = substraitTypeToVeloxType(castExpr.type()); + auto type = SubstraitParser::parseType(castExpr.type()); bool nullOnFailure = isNullOnFailure(castExpr.failure_behavior()); std::vector inputs{toVeloxExpr(castExpr.input(), inputType)}; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlan.cc b/cpp/velox/substrait/SubstraitToVeloxPlan.cc index d36401cbe2ca..a0f5797e33e9 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlan.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlan.cc @@ -355,11 +355,11 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: for (const auto& arg : aggFunction.arguments()) { aggParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), inputType)); } - auto aggVeloxType = substraitTypeToVeloxType(aggFunction.output_type()); + auto aggVeloxType = SubstraitParser::parseType(aggFunction.output_type()); auto aggExpr = std::make_shared(aggVeloxType, std::move(aggParams), funcName); - const auto& functionSpec = SubstraitParser::findFunctionSpec(functionMap_, aggFunction.function_reference()); - std::vector rawInputTypes = sigToTypes(functionSpec); + std::vector rawInputTypes = + SubstraitParser::sigToTypes(SubstraitParser::findFunctionSpec(functionMap_, aggFunction.function_reference())); aggregates.emplace_back(core::AggregationNode::Aggregate{aggExpr, rawInputTypes, mask, {}, {}}); } @@ -611,7 +611,7 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: for (const auto& arg : windowFunction.arguments()) { windowParams.emplace_back(exprConverter_->toVeloxExpr(arg.value(), inputType)); } - auto windowVeloxType = substraitTypeToVeloxType(windowFunction.output_type()); + auto windowVeloxType = SubstraitParser::parseType(windowFunction.output_type()); auto windowCall = std::make_shared(windowVeloxType, std::move(windowParams), funcName); auto upperBound = windowFunction.upper_bound(); auto lowerBound = windowFunction.lower_bound(); @@ -779,12 +779,8 @@ core::PlanNodePtr SubstraitToVeloxPlanConverter::toVeloxPlan(const ::substrait:: } colNameList.emplace_back(fieldName); } - auto substraitTypeList = SubstraitParser::parseNamedStruct(baseSchema); + veloxTypeList = SubstraitParser::parseNamedStruct(baseSchema, asLowerCase); isPartitionColumns = SubstraitParser::parsePartitionColumns(baseSchema); - veloxTypeList.reserve(substraitTypeList.size()); - for (const auto& substraitType : substraitTypeList) { - veloxTypeList.emplace_back(toVeloxType(substraitType->type, asLowerCase)); - } } // Parse local files and construct split info. @@ -1048,7 +1044,7 @@ void SubstraitToVeloxPlanConverter::flattenConditions( const auto& sFunc = substraitFilter.scalar_function(); auto filterNameSpec = SubstraitParser::findFunctionSpec(functionMap_, sFunc.function_reference()); // TODO: Only and relation is supported here. - if (SubstraitParser::getSubFunctionName(filterNameSpec) == "and") { + if (SubstraitParser::getNameBeforeDelimiter(filterNameSpec) == "and") { for (const auto& sCondition : sFunc.arguments()) { flattenConditions(sCondition.value(), scalarFunctions, singularOrLists, ifThens); } @@ -1114,7 +1110,7 @@ void SubstraitToVeloxPlanConverter::extractJoinKeys( auto visited = expressions.back(); expressions.pop_back(); if (visited->rex_type_case() == ::substrait::Expression::RexTypeCase::kScalarFunction) { - const auto& funcName = SubstraitParser::getSubFunctionName( + const auto& funcName = SubstraitParser::getNameBeforeDelimiter( SubstraitParser::findVeloxFunction(functionMap_, visited->scalar_function().function_reference())); const auto& args = visited->scalar_function().arguments(); if (funcName == "and") { @@ -1146,7 +1142,7 @@ connector::hive::SubfieldFilters SubstraitToVeloxPlanConverter::createSubfieldFi // Process scalarFunctions. for (const auto& scalarFunction : scalarFunctions) { auto filterNameSpec = SubstraitParser::findFunctionSpec(functionMap_, scalarFunction.function_reference()); - auto filterName = SubstraitParser::getSubFunctionName(filterNameSpec); + auto filterName = SubstraitParser::getNameBeforeDelimiter(filterNameSpec); if (filterName == sNot) { VELOX_CHECK(scalarFunction.arguments().size() == 1); @@ -1289,7 +1285,7 @@ bool SubstraitToVeloxPlanConverter::canPushdownNot( auto argFunction = SubstraitParser::findFunctionSpec(functionMap_, notArg.value().scalar_function().function_reference()); - auto functionName = SubstraitParser::getSubFunctionName(argFunction); + auto functionName = SubstraitParser::getNameBeforeDelimiter(argFunction); static const std::unordered_set supportedNotFunctions = {sGte, sGt, sLte, sLt, sEqual}; @@ -1318,7 +1314,7 @@ bool SubstraitToVeloxPlanConverter::canPushdownOr( if (arg.value().has_scalar_function()) { auto nameSpec = SubstraitParser::findFunctionSpec(functionMap_, arg.value().scalar_function().function_reference()); - auto functionName = SubstraitParser::getSubFunctionName(nameSpec); + auto functionName = SubstraitParser::getNameBeforeDelimiter(nameSpec); uint32_t fieldIdx; bool isFieldOrWithLiteral = fieldOrWithLiteral(arg.value().scalar_function().arguments(), fieldIdx); @@ -1372,7 +1368,7 @@ void SubstraitToVeloxPlanConverter::separateFilters( for (const auto& scalarFunction : scalarFunctions) { auto filterNameSpec = SubstraitParser::findFunctionSpec(functionMap_, scalarFunction.function_reference()); - auto filterName = SubstraitParser::getSubFunctionName(filterNameSpec); + auto filterName = SubstraitParser::getNameBeforeDelimiter(filterNameSpec); // Add all decimal filters to remaining functions because their pushdown are not supported. if (format == dwio::common::FileFormat::ORC && scalarFunction.arguments().size() > 0) { auto value = scalarFunction.arguments().at(0).value(); @@ -1507,7 +1503,7 @@ void SubstraitToVeloxPlanConverter::setFilterInfo( std::vector& columnToFilterInfo, bool reverse) { auto nameSpec = SubstraitParser::findFunctionSpec(functionMap_, scalarFunction.function_reference()); - auto functionName = SubstraitParser::getSubFunctionName(nameSpec); + auto functionName = SubstraitParser::getNameBeforeDelimiter(nameSpec); // Extract the column index and column bound from the scalar function. std::optional colIdx; diff --git a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc index d1dde3a7c77d..2d4247f650c3 100644 --- a/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc +++ b/cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc @@ -103,7 +103,7 @@ bool SubstraitToVeloxPlanValidator::validateInputTypes( const auto& sTypes = inputType.struct_().types(); for (const auto& sType : sTypes) { try { - types.emplace_back(substraitTypeToVeloxType(sType)); + types.emplace_back(SubstraitParser::parseType(sType)); } catch (const VeloxException& err) { logValidateMsg("native validation failed due to: Type is not supported, " + err.message()); return false; @@ -204,9 +204,8 @@ bool SubstraitToVeloxPlanValidator::validateScalarFunction( const auto& function = SubstraitParser::findFunctionSpec(planConverter_.getFunctionMap(), scalarFunction.function_reference()); - const auto& name = SubstraitParser::getSubFunctionName(function); - std::vector types; - SubstraitParser::getSubFunctionTypes(function, types); + const auto& name = SubstraitParser::getNameBeforeDelimiter(function); + std::vector types = SubstraitParser::getSubFunctionTypes(function); if (name == "round") { return validateRound(scalarFunction, inputType); @@ -293,7 +292,7 @@ bool SubstraitToVeloxPlanValidator::validateCast( return false; } - const auto& toType = substraitTypeToVeloxType(castExpr.type()); + const auto& toType = SubstraitParser::parseType(castExpr.type()); if (toType->kind() == TypeKind::TIMESTAMP) { logValidateMsg("native validation failed due to: Casting to TIMESTAMP is not supported."); return false; @@ -490,7 +489,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo try { const auto& windowFunction = smea.measure(); funcSpecs.emplace_back(planConverter_.findFuncSpec(windowFunction.function_reference())); - substraitTypeToVeloxType(windowFunction.output_type()); + SubstraitParser::parseType(windowFunction.output_type()); for (const auto& arg : windowFunction.arguments()) { auto typeCase = arg.value().rex_type_case(); switch (typeCase) { @@ -531,7 +530,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::WindowRel& windo // Validate supported aggregate functions. static const std::unordered_set unsupportedFuncs = {"collect_list", "collect_set"}; for (const auto& funcSpec : funcSpecs) { - auto funcName = SubstraitParser::getSubFunctionName(funcSpec); + auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec); if (unsupportedFuncs.find(funcName) != unsupportedFuncs.end()) { logValidateMsg("native validation failed due to: " + funcName + " was not supported in WindowRel."); return false; @@ -853,7 +852,7 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait std::vector types; bool isDecimal = false; try { - types = sigToTypes(funcSpec); + types = SubstraitParser::sigToTypes(funcSpec); for (const auto& type : types) { if (!isDecimal && type->isDecimal()) { isDecimal = true; @@ -865,7 +864,7 @@ bool SubstraitToVeloxPlanValidator::validateAggRelFunctionType(const ::substrait err.message()); return false; } - auto funcName = SubstraitParser::mapToVeloxFunction(SubstraitParser::getSubFunctionName(funcSpec), isDecimal); + auto funcName = SubstraitParser::mapToVeloxFunction(SubstraitParser::getNameBeforeDelimiter(funcSpec), isDecimal); auto signaturesOpt = exec::getAggregateFunctionSignatures(funcName); if (!signaturesOpt) { logValidateMsg( @@ -954,9 +953,9 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag const auto& aggFunction = smea.measure(); const auto& functionSpec = planConverter_.findFuncSpec(aggFunction.function_reference()); funcSpecs.emplace_back(functionSpec); - substraitTypeToVeloxType(aggFunction.output_type()); + SubstraitParser::parseType(aggFunction.output_type()); // Validate the size of arguments. - if (SubstraitParser::getSubFunctionName(functionSpec) == "count" && aggFunction.arguments().size() > 1) { + if (SubstraitParser::getNameBeforeDelimiter(functionSpec) == "count" && aggFunction.arguments().size() > 1) { logValidateMsg("native validation failed due to: count should have only one argument"); // Count accepts only one argument. return false; @@ -1024,7 +1023,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::AggregateRel& ag "approx_distinct"}; for (const auto& funcSpec : funcSpecs) { - auto funcName = SubstraitParser::getSubFunctionName(funcSpec); + auto funcName = SubstraitParser::getNameBeforeDelimiter(funcSpec); if (supportedAggFuncs.find(funcName) == supportedAggFuncs.end()) { logValidateMsg("native validation failed due to: " + funcName + " was not supported in AggregateRel."); return false; @@ -1066,11 +1065,7 @@ bool SubstraitToVeloxPlanValidator::validate(const ::substrait::ReadRel& readRel std::vector veloxTypeList; if (readRel.has_base_schema()) { const auto& baseSchema = readRel.base_schema(); - auto substraitTypeList = SubstraitParser::parseNamedStruct(baseSchema); - veloxTypeList.reserve(substraitTypeList.size()); - for (const auto& substraitType : substraitTypeList) { - veloxTypeList.emplace_back(toVeloxType(substraitType->type)); - } + veloxTypeList = SubstraitParser::parseNamedStruct(baseSchema); } int32_t inputPlanNodeId = 0; diff --git a/cpp/velox/substrait/TypeUtils.cc b/cpp/velox/substrait/TypeUtils.cc deleted file mode 100644 index 0d0d1ac9bed5..000000000000 --- a/cpp/velox/substrait/TypeUtils.cc +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "TypeUtils.h" -#include "SubstraitParser.h" -#include "velox/type/Type.h" - -namespace gluten { -std::vector getTypesFromCompoundName(std::string_view compoundName) { - // CompoundName is like ARRAY or MAP - // or ROW,ROW> - // the position of then delimiter is where the number of leftAngleBracket - // equals rightAngleBracket need to split. - std::vector types; - std::vector angleBracketNumEqualPos; - auto leftAngleBracketPos = compoundName.find("<"); - auto rightAngleBracketPos = compoundName.rfind(">"); - auto typesName = compoundName.substr(leftAngleBracketPos + 1, rightAngleBracketPos - leftAngleBracketPos - 1); - int leftAngleBracketNum = 0; - int rightAngleBracketNum = 0; - for (auto index = 0; index < typesName.length(); index++) { - if (typesName[index] == '<') { - leftAngleBracketNum++; - } - if (typesName[index] == '>') { - rightAngleBracketNum++; - } - if (typesName[index] == ',' && rightAngleBracketNum == leftAngleBracketNum) { - angleBracketNumEqualPos.push_back(index); - } - } - int startPos = 0; - for (auto delimeterPos : angleBracketNumEqualPos) { - types.emplace_back(typesName.substr(startPos, delimeterPos - startPos)); - startPos = delimeterPos + 1; - } - types.emplace_back(std::string_view(typesName.data() + startPos, typesName.length() - startPos)); - return types; -} - -// TODO Refactor using Bison. -std::string_view getNameBeforeDelimiter(const std::string& compoundName, const std::string& delimiter) { - std::size_t pos = compoundName.find(delimiter); - if (pos == std::string::npos) { - return compoundName; - } - return std::string_view(compoundName.data(), pos); -} - -std::pair getPrecisionAndScale(const std::string& typeName) { - std::size_t start = typeName.find_first_of("<"); - std::size_t end = typeName.find_last_of(">"); - if (start == std::string::npos || end == std::string::npos) { - throw std::runtime_error("Invalid decimal type."); - } - - std::string decimalType = typeName.substr(start + 1, end - start - 1); - std::size_t token_pos = decimalType.find_first_of(","); - auto precision = stoi(decimalType.substr(0, token_pos)); - auto scale = stoi(decimalType.substr(token_pos + 1, decimalType.length() - 1)); - return std::make_pair(precision, scale); -} - -TypePtr toVeloxType(const std::string& typeName, bool asLowerCase) { - VELOX_CHECK(!typeName.empty(), "Cannot convert empty string to Velox type."); - auto type = std::string(getNameBeforeDelimiter(typeName, "<")); - if (DATE()->toString() == type) { - return DATE(); - } - if (type == "SHORT_DECIMAL") { - auto decimal = getPrecisionAndScale(typeName); - return DECIMAL(decimal.first, decimal.second); - } - auto typeKind = mapNameToTypeKind(type); - switch (typeKind) { - case TypeKind::BOOLEAN: - return BOOLEAN(); - case TypeKind::TINYINT: - return TINYINT(); - case TypeKind::SMALLINT: - return SMALLINT(); - case TypeKind::INTEGER: - return INTEGER(); - case TypeKind::BIGINT: - return BIGINT(); - case TypeKind::HUGEINT: { - auto decimal = getPrecisionAndScale(typeName); - return DECIMAL(decimal.first, decimal.second); - } - case TypeKind::REAL: - return REAL(); - case TypeKind::DOUBLE: - return DOUBLE(); - case TypeKind::VARCHAR: - return VARCHAR(); - case TypeKind::VARBINARY: - return VARBINARY(); - case TypeKind::ARRAY: { - auto fieldTypes = getTypesFromCompoundName(typeName); - VELOX_CHECK_EQ(fieldTypes.size(), 1, "The size of ARRAY type should be only one."); - return ARRAY(toVeloxType(std::string(fieldTypes[0]), asLowerCase)); - } - case TypeKind::MAP: { - auto fieldTypes = getTypesFromCompoundName(typeName); - VELOX_CHECK_EQ(fieldTypes.size(), 2, "The size of MAP type should be two."); - auto keyType = toVeloxType(std::string(fieldTypes[0]), asLowerCase); - auto valueType = toVeloxType(std::string(fieldTypes[1]), asLowerCase); - return MAP(keyType, valueType); - } - case TypeKind::ROW: { - auto fieldTypes = getTypesFromCompoundName(typeName); - VELOX_CHECK(!fieldTypes.empty(), "Converting empty ROW type from Substrait to Velox is not supported."); - - std::vector types; - std::vector names; - for (int idx = 0; idx < fieldTypes.size(); idx++) { - std::string fieldTypeAndName = std::string(fieldTypes[idx]); - size_t pos = fieldTypeAndName.find_last_of(':'); - if (pos == std::string::npos) { - // Name does not exist. - types.emplace_back(toVeloxType(fieldTypeAndName, asLowerCase)); - names.emplace_back("col_" + std::to_string(idx)); - } else { - types.emplace_back(toVeloxType(fieldTypeAndName.substr(0, pos), asLowerCase)); - std::string fieldName = fieldTypeAndName.substr(pos + 1, fieldTypeAndName.length()); - if (asLowerCase) { - folly::toLowerAscii(fieldName); - } - names.emplace_back(fieldName); - } - } - return ROW(std::move(names), std::move(types)); - } - case TypeKind::TIMESTAMP: { - return TIMESTAMP(); - } - case TypeKind::UNKNOWN: - return UNKNOWN(); - default: - VELOX_NYI("Velox type conversion not supported for type {}.", typeName); - } -} - -TypePtr substraitTypeToVeloxType(const std::string& substraitType) { - return toVeloxType(SubstraitParser::parseType(substraitType)); -} - -TypePtr substraitTypeToVeloxType(const ::substrait::Type& substraitType) { - return toVeloxType(SubstraitParser::parseType(substraitType).type); -} - -TypePtr getRowType(const std::string& structType) { - // Struct info is in the format of struct. - // TODO: nested struct is not supported. - auto structStart = structType.find_first_of('<'); - auto structEnd = structType.find_last_of('>'); - VELOX_CHECK( - structEnd - structStart > 1, "native validation failed due to: More information is needed to create RowType"); - std::string childrenTypes = structType.substr(structStart + 1, structEnd - structStart - 1); - - // Split the types with delimiter. - std::string delimiter = ","; - std::size_t pos; - std::vector types; - std::vector names; - while ((pos = childrenTypes.find(delimiter)) != std::string::npos) { - const auto& typeStr = childrenTypes.substr(0, pos); - std::string decDelimiter = ">"; - if (typeStr.find("dec") != std::string::npos) { - std::size_t endPos = childrenTypes.find(decDelimiter); - VELOX_CHECK(endPos >= pos + 1, "Decimal scale is expected."); - const auto& decimalStr = typeStr + childrenTypes.substr(pos, endPos - pos) + decDelimiter; - types.emplace_back(getDecimalType(decimalStr)); - names.emplace_back(""); - childrenTypes.erase(0, endPos + delimiter.length() + decDelimiter.length()); - continue; - } - - types.emplace_back(substraitTypeToVeloxType(typeStr)); - names.emplace_back(""); - childrenTypes.erase(0, pos + delimiter.length()); - } - types.emplace_back(substraitTypeToVeloxType(childrenTypes)); - names.emplace_back(""); - return std::make_shared(std::move(names), std::move(types)); -} - -TypePtr getDecimalType(const std::string& decimalType) { - // Decimal info is in the format of dec. - auto precisionStart = decimalType.find_first_of('<'); - auto tokenIndex = decimalType.find_first_of(','); - auto scaleStart = decimalType.find_first_of('>'); - auto precision = stoi(decimalType.substr(precisionStart + 1, (tokenIndex - precisionStart - 1))); - auto scale = stoi(decimalType.substr(tokenIndex + 1, (scaleStart - tokenIndex - 1))); - return DECIMAL(precision, scale); -} - -std::vector sigToTypes(const std::string& functionSig) { - std::vector typeStrs; - SubstraitParser::getSubFunctionTypes(functionSig, typeStrs); - std::vector types; - types.reserve(typeStrs.size()); - for (const auto& typeStr : typeStrs) { - if (typeStr.find("struct") != std::string::npos) { - types.emplace_back(getRowType(typeStr)); - } else if (typeStr.find("dec") != std::string::npos) { - types.emplace_back(getDecimalType(typeStr)); - } else { - types.emplace_back(substraitTypeToVeloxType(typeStr)); - } - } - return types; -} - -} // namespace gluten diff --git a/cpp/velox/substrait/TypeUtils.h b/cpp/velox/substrait/TypeUtils.h index 7266dbd9a236..b2aaf725789d 100644 --- a/cpp/velox/substrait/TypeUtils.h +++ b/cpp/velox/substrait/TypeUtils.h @@ -22,30 +22,6 @@ using namespace facebook::velox; namespace gluten { - -#ifndef TOVELOXTYPE_H -#define TOVELOXTYPE_H - -/// Return the Velox type according to the typename. -TypePtr toVeloxType(const std::string& typeName, bool asLowerCase = false); - -/// Return the Velox type according to substrait type string. -TypePtr substraitTypeToVeloxType(const std::string& substraitType); - -/// Return the Velox type according to substrait type. -TypePtr substraitTypeToVeloxType(const ::substrait::Type& substraitType); - -/// Create RowType based on the type information in string. -TypePtr getRowType(const std::string& structType); - -/// Create DecimalType based on the type information in string. -TypePtr getDecimalType(const std::string& decimalType); - -std::vector sigToTypes(const std::string& functionSig); - -#endif /* TOVELOXTYPE_H */ - -std::string_view getNameBeforeDelimiter(const std::string& compoundName, const std::string& delimiter); #ifndef RANGETRAITS_H #define RANGETRAITS_H diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.cc b/cpp/velox/substrait/VeloxSubstraitSignature.cc index 4606132586c2..ef1055f582cb 100644 --- a/cpp/velox/substrait/VeloxSubstraitSignature.cc +++ b/cpp/velox/substrait/VeloxSubstraitSignature.cc @@ -60,6 +60,105 @@ std::string VeloxSubstraitSignature::toSubstraitSignature(const TypePtr& type) { } } +TypePtr VeloxSubstraitSignature::fromSubstraitSignature(const std::string& signature) { + if (signature == "bool") { + return BOOLEAN(); + } + + if (signature == "i8") { + return TINYINT(); + } + + if (signature == "i16") { + return SMALLINT(); + } + + if (signature == "i32") { + return INTEGER(); + } + + if (signature == "i64") { + return BIGINT(); + } + + if (signature == "fp32") { + return REAL(); + } + + if (signature == "fp64") { + return DOUBLE(); + } + + if (signature == "str") { + return VARCHAR(); + } + + if (signature == "vbin") { + return VARBINARY(); + } + + if (signature == "ts") { + return TIMESTAMP(); + } + + if (signature == "date") { + return DATE(); + } + + auto startWith = [](const std::string& str, const std::string& prefix) { + return str.size() >= prefix.size() && str.substr(0, prefix.size()) == prefix; + }; + + if (startWith(signature, "dec")) { + // Decimal type name is in the format of dec. + auto precisionStart = signature.find_first_of('<'); + auto tokenIndex = signature.find_first_of(','); + auto scaleEnd = signature.find_first_of('>'); + auto precision = stoi(signature.substr(precisionStart + 1, (tokenIndex - precisionStart - 1))); + auto scale = stoi(signature.substr(tokenIndex + 1, (scaleEnd - tokenIndex - 1))); + return DECIMAL(precision, scale); + } + + if (startWith(signature, "struct")) { + // Struct type name is in the format of struct. + auto structStart = signature.find_first_of('<'); + auto structEnd = signature.find_last_of('>'); + VELOX_CHECK( + structEnd - structStart > 1, "Native validation failed due to: more information is needed to create RowType"); + std::string childrenTypes = signature.substr(structStart + 1, structEnd - structStart - 1); + + // Split the types with delimiter. + std::string delimiter = ","; + std::size_t pos; + std::vector types; + std::vector names; + while ((pos = childrenTypes.find(delimiter)) != std::string::npos) { + auto typeStr = childrenTypes.substr(0, pos); + std::size_t endPos = pos; + if (startWith(typeStr, "dec") || startWith(typeStr, "struct")) { + endPos = childrenTypes.find(">") + 1; + if (endPos > pos) { + typeStr += childrenTypes.substr(pos, endPos - pos); + } else { + // For nested case, the end '>' could missing, + // so the last position is treated as end. + typeStr += childrenTypes.substr(pos); + endPos = childrenTypes.size(); + } + } + types.emplace_back(fromSubstraitSignature(typeStr)); + names.emplace_back(""); + childrenTypes.erase(0, endPos + delimiter.length()); + } + if (childrenTypes.size() > 0 && !startWith(childrenTypes, ">")) { + types.emplace_back(fromSubstraitSignature(childrenTypes)); + names.emplace_back(""); + } + return std::make_shared(std::move(names), std::move(types)); + } + VELOX_UNSUPPORTED("Substrait type signature conversion to Velox type not supported for {}.", signature); +} + std::string VeloxSubstraitSignature::toSubstraitSignature( const std::string& functionName, const std::vector& arguments) { diff --git a/cpp/velox/substrait/VeloxSubstraitSignature.h b/cpp/velox/substrait/VeloxSubstraitSignature.h index 8091b6cc5fd0..8a54e4edcb6e 100644 --- a/cpp/velox/substrait/VeloxSubstraitSignature.h +++ b/cpp/velox/substrait/VeloxSubstraitSignature.h @@ -34,6 +34,9 @@ class VeloxSubstraitSignature { /// https://substrait.io/extensions/#function-signature-compound-names. static std::string toSubstraitSignature(const TypePtr& type); + /// Given a substrait type signature, return the Velox type. + static TypePtr fromSubstraitSignature(const std::string& signature); + /// Given a velox scalar function name and argument types, return the /// substrait function signature. static std::string toSubstraitSignature(const std::string& functionName, const std::vector& arguments); diff --git a/cpp/velox/tests/FunctionTest.cc b/cpp/velox/tests/FunctionTest.cc index e3e415bf574e..a5d2b5fade3d 100644 --- a/cpp/velox/tests/FunctionTest.cc +++ b/cpp/velox/tests/FunctionTest.cc @@ -30,6 +30,8 @@ #include "velox/core/QueryCtx.h" +#include "substrait/SubstraitParser.h" + using namespace facebook::velox; using namespace facebook::velox::test; @@ -72,15 +74,15 @@ TEST_F(FunctionTest, getIdxFromNodeName) { TEST_F(FunctionTest, getNameBeforeDelimiter) { std::string functionSpec = "lte:fp64_fp64"; - std::string_view funcName = getNameBeforeDelimiter(functionSpec, ":"); + std::string_view funcName = SubstraitParser::getNameBeforeDelimiter(functionSpec); ASSERT_EQ(funcName, "lte"); functionSpec = "lte:"; - funcName = getNameBeforeDelimiter(functionSpec, ":"); + funcName = SubstraitParser::getNameBeforeDelimiter(functionSpec); ASSERT_EQ(funcName, "lte"); functionSpec = "lte"; - funcName = getNameBeforeDelimiter(functionSpec, ":"); + funcName = SubstraitParser::getNameBeforeDelimiter(functionSpec); ASSERT_EQ(funcName, "lte"); } @@ -180,18 +182,36 @@ TEST_F(FunctionTest, setVectorFromVariants) { } TEST_F(FunctionTest, getFunctionType) { - std::vector types; - SubstraitParser::getSubFunctionTypes("sum:opt_i32", types); + std::vector types = SubstraitParser::getSubFunctionTypes("sum:opt_i32"); ASSERT_EQ("i32", types[0]); - types.clear(); - SubstraitParser::getSubFunctionTypes("sum:i32", types); + types = SubstraitParser::getSubFunctionTypes("sum:i32"); ASSERT_EQ("i32", types[0]); - types.clear(); - SubstraitParser::getSubFunctionTypes("sum:opt_str_str", types); + types = SubstraitParser::getSubFunctionTypes("sum:opt_str_str"); ASSERT_EQ(2, types.size()); ASSERT_EQ("str", types[0]); ASSERT_EQ("str", types[1]); } + +TEST_F(FunctionTest, sigToTypes) { + std::vector types = SubstraitParser::sigToTypes("sum:opt_i32"); + ASSERT_EQ(types[0]->kind(), TypeKind::INTEGER); + + types = SubstraitParser::sigToTypes("and:opt_bool_bool"); + ASSERT_EQ(2, types.size()); + ASSERT_EQ(types[0]->kind(), TypeKind::BOOLEAN); + ASSERT_EQ(types[1]->kind(), TypeKind::BOOLEAN); + + types = SubstraitParser::sigToTypes("sum:dec<12,9>"); + ASSERT_EQ(getDecimalPrecisionScale(*types[0]).first, 12); + ASSERT_EQ(getDecimalPrecisionScale(*types[0]).second, 9); + + types = SubstraitParser::sigToTypes("sum:struct,bool>"); + ASSERT_EQ(types[0]->kind(), TypeKind::ROW); + ASSERT_EQ(types[0]->childAt(0)->kind(), TypeKind::INTEGER); + ASSERT_EQ(types[0]->childAt(1)->kind(), TypeKind::VARCHAR); + ASSERT_TRUE(types[0]->childAt(2)->isDecimal()); + ASSERT_EQ(types[0]->childAt(3)->kind(), TypeKind::BOOLEAN); +} } // namespace gluten diff --git a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc index 2c539d0de7ef..fbfe14f7c92c 100644 --- a/cpp/velox/tests/VeloxSubstraitSignatureTest.cc +++ b/cpp/velox/tests/VeloxSubstraitSignatureTest.cc @@ -30,6 +30,10 @@ class VeloxSubstraitSignatureTest : public ::testing::Test { functions::prestosql::registerAllScalarFunctions(); } + static TypePtr fromSubstraitSignature(const std::string& signature) { + return VeloxSubstraitSignature::fromSubstraitSignature(signature); + } + static std::string toSubstraitSignature(const TypePtr& type) { return VeloxSubstraitSignature::toSubstraitSignature(type); } @@ -91,4 +95,50 @@ TEST_F(VeloxSubstraitSignatureTest, toSubstraitSignatureWithFunctionNameAndArgum ASSERT_ANY_THROW(toSubstraitSignature("transform_keys", std::move(types))); } +TEST_F(VeloxSubstraitSignatureTest, fromSubstraitSignature) { + ASSERT_EQ(fromSubstraitSignature("bool")->kind(), TypeKind::BOOLEAN); + ASSERT_EQ(fromSubstraitSignature("i8")->kind(), TypeKind::TINYINT); + ASSERT_EQ(fromSubstraitSignature("i16")->kind(), TypeKind::SMALLINT); + ASSERT_EQ(fromSubstraitSignature("i32")->kind(), TypeKind::INTEGER); + ASSERT_EQ(fromSubstraitSignature("i64")->kind(), TypeKind::BIGINT); + ASSERT_EQ(fromSubstraitSignature("fp32")->kind(), TypeKind::REAL); + ASSERT_EQ(fromSubstraitSignature("fp64")->kind(), TypeKind::DOUBLE); + ASSERT_EQ(fromSubstraitSignature("str")->kind(), TypeKind::VARCHAR); + ASSERT_EQ(fromSubstraitSignature("vbin")->kind(), TypeKind::VARBINARY); + ASSERT_EQ(fromSubstraitSignature("ts")->kind(), TypeKind::TIMESTAMP); + ASSERT_EQ(fromSubstraitSignature("date")->kind(), TypeKind::INTEGER); + ASSERT_EQ(fromSubstraitSignature("dec<18,2>")->kind(), TypeKind::BIGINT); + ASSERT_EQ(fromSubstraitSignature("dec<19,2>")->kind(), TypeKind::HUGEINT); + + // Struct type test. + auto type = fromSubstraitSignature("struct>"); + ASSERT_EQ(type->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->kind(), TypeKind::BOOLEAN); + ASSERT_EQ(type->childAt(1)->kind(), TypeKind::VARBINARY); + ASSERT_EQ(type->childAt(2)->kind(), TypeKind::BIGINT); + type = fromSubstraitSignature("struct>"); + ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(1)->childAt(0)->kind(), TypeKind::TINYINT); + ASSERT_EQ(type->childAt(1)->childAt(1)->kind(), TypeKind::REAL); + type = fromSubstraitSignature("struct,vbin,ts,dec<9,2>>"); + ASSERT_EQ(type->childAt(1)->kind(), TypeKind::ROW); + type = fromSubstraitSignature("struct,i16>"); + ASSERT_EQ(type->childAt(0)->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::TIMESTAMP); + type = fromSubstraitSignature("struct>>"); + ASSERT_EQ(type->childAt(0)->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::BIGINT); + type = fromSubstraitSignature("struct>>"); + ASSERT_EQ(type->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->childAt(0)->kind(), TypeKind::ROW); + ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(0)->kind(), TypeKind::TINYINT); + type = fromSubstraitSignature("struct>>"); + ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(0)->kind(), TypeKind::TINYINT); + ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), TypeKind::VARCHAR); + type = fromSubstraitSignature("struct>>>"); + ASSERT_EQ(type->childAt(0)->childAt(0)->childAt(1)->kind(), TypeKind::HUGEINT); + ASSERT_ANY_THROW(fromSubstraitSignature("other")->kind()); +} + } // namespace gluten diff --git a/cpp/velox/tests/VeloxToSubstraitTypeTest.cc b/cpp/velox/tests/VeloxToSubstraitTypeTest.cc index c9b95c711332..bc9e9df05c05 100644 --- a/cpp/velox/tests/VeloxToSubstraitTypeTest.cc +++ b/cpp/velox/tests/VeloxToSubstraitTypeTest.cc @@ -32,7 +32,7 @@ class VeloxToSubstraitTypeTest : public ::testing::Test { google::protobuf::Arena arena; auto substraitType = typeConvertor_->toSubstraitType(arena, type); - auto sameType = substraitTypeToVeloxType(substraitType); + auto sameType = SubstraitParser::parseType(substraitType); ASSERT_TRUE(sameType->kindEquals(type)) << "Expected: " << type->toString() << ", but got: " << sameType->toString(); } @@ -59,6 +59,6 @@ TEST_F(VeloxToSubstraitTypeTest, basic) { testTypeConversion(ROW({"a", "b", "c"}, {BIGINT(), BOOLEAN(), VARCHAR()})); testTypeConversion(ROW({"a", "b", "c"}, {BIGINT(), ROW({"x", "y"}, {BOOLEAN(), VARCHAR()}), REAL()})); - ASSERT_ANY_THROW(testTypeConversion(ROW({}, {}))); + testTypeConversion(ROW({}, {})); } } // namespace gluten