Skip to content

Commit

Permalink
[GIE Compiler] fix unit test bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
shirly121 committed Oct 19, 2023
1 parent 88737d9 commit a28c0e1
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,19 @@

package com.alibaba.graphscope.common.ir.meta.procedure;

import com.alibaba.graphscope.common.ir.type.ArbitraryArrayType;
import com.alibaba.graphscope.common.ir.type.ArbitraryMapType;
import com.alibaba.graphscope.common.ir.type.GraphTypeFactoryImpl;
import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.type.SqlTypeName;

import java.util.List;
import java.util.stream.Collectors;

public class Utils {
public static String typeToStr(RelDataType dataType) {
SqlTypeName typeName = dataType.getSqlTypeName();
Expand All @@ -30,8 +37,49 @@ public static String typeToStr(RelDataType dataType) {
} else if (typeName == SqlTypeName.BIGINT) {
return "LONG";
} else if (typeName == SqlTypeName.ARRAY || typeName == SqlTypeName.MULTISET) {
return String.format(
"%s(%s)", typeName.getName(), typeToStr(dataType.getComponentType()));
if (dataType instanceof ArbitraryArrayType) {
List<RelDataType> componentTypes =
((ArbitraryArrayType) dataType).getComponentTypes();
StringBuilder sb = new StringBuilder();
sb.append(typeName.getName() + "(");
for (int i = 0; i < componentTypes.size(); i++) {
sb.append(typeToStr(componentTypes.get(i)));
if (i != componentTypes.size() - 1) {
sb.append(",");
}
}
sb.append(")");
return sb.toString();
} else {
return String.format(
"%s(%s)", typeName.getName(), typeToStr(dataType.getComponentType()));
}
} else if (typeName == SqlTypeName.MAP) {
if (dataType instanceof ArbitraryMapType) {
List<RelDataType> keyTypes = ((ArbitraryMapType) dataType).getKeyTypes();
List<RelDataType> valueTypes = ((ArbitraryMapType) dataType).getValueTypes();
Preconditions.checkArgument(
keyTypes.size() == valueTypes.size(),
"key size and value size are not equal in " + dataType);
StringBuilder sb = new StringBuilder();
sb.append(typeName.getName() + "(");
for (int i = 0; i < keyTypes.size(); i++) {
sb.append(typeToStr(keyTypes.get(i)));
sb.append(",");
sb.append(typeToStr(valueTypes.get(i)));
if (i != keyTypes.size() - 1) {
sb.append(",");
}
}
sb.append(")");
return sb.toString();
} else {
return String.format(
"%s(%s,%s)",
typeName.getName(),
typeToStr(dataType.getKeyType()),
typeToStr(dataType.getValueType()));
}
} else {
// todo: convert vertex or edge type to string
return typeName.getName();
Expand All @@ -45,21 +93,51 @@ public static RelDataType strToType(String typeStr, RelDataTypeFactory typeFacto
} else if (typeStr.equals("LONG")) {
return typeFactory.createSqlType(SqlTypeName.BIGINT);
} else if (typeStr.startsWith(SqlTypeName.ARRAY.getName())) {
RelDataType componentType = strToType(getComponentTypeStr(typeStr), typeFactory);
return typeFactory.createArrayType(componentType, -1);
List<String> componentTypeStr = getComponentTypeStr(typeStr);
if (componentTypeStr.size() == 1) {
RelDataType componentType = strToType(componentTypeStr.get(0), typeFactory);
return typeFactory.createArrayType(componentType, -1);
} else {
List<RelDataType> componentTypes =
componentTypeStr.stream()
.map(k -> strToType(k, typeFactory))
.collect(Collectors.toList());
return ((GraphTypeFactoryImpl) typeFactory)
.createArbitraryArrayType(componentTypes, false);
}
} else if (typeStr.startsWith(SqlTypeName.MAP.getName())) {
List<String> componentTypeStr = getComponentTypeStr(typeStr);
if (componentTypeStr.size() == 2) {
RelDataType keyType = strToType(componentTypeStr.get(0), typeFactory);
RelDataType valueType = strToType(componentTypeStr.get(1), typeFactory);
return typeFactory.createMapType(keyType, valueType);
} else {
List<RelDataType> keyTypes = Lists.newArrayList();
List<RelDataType> valueTypes = Lists.newArrayList();
for (int i = 0; i < componentTypeStr.size(); i++) {
if ((i & 1) == 0) {
keyTypes.add(strToType(componentTypeStr.get(i), typeFactory));
} else {
valueTypes.add(strToType(componentTypeStr.get(i), typeFactory));
}
}
return ((GraphTypeFactoryImpl) typeFactory)
.createArbitraryMapType(keyTypes, valueTypes, false);
}
} else if (typeStr.startsWith(SqlTypeName.MULTISET.getName())) {
RelDataType componentType = strToType(getComponentTypeStr(typeStr), typeFactory);
RelDataType componentType = strToType(getComponentTypeStr(typeStr).get(0), typeFactory);
return typeFactory.createMultisetType(componentType, -1);
} else {
return typeFactory.createSqlType(SqlTypeName.valueOf(typeStr));
}
}

private static String getComponentTypeStr(String typeStr) {
private static List<String> getComponentTypeStr(String typeStr) {
int leftBraceIdx = typeStr.indexOf("(");
int rightBraceIdx = typeStr.indexOf(")");
int rightBraceIdx = typeStr.lastIndexOf(")");
Preconditions.checkArgument(
leftBraceIdx != -1 && rightBraceIdx != -1, "invalid type pattern " + typeStr);
return typeStr.substring(leftBraceIdx + 1, rightBraceIdx);
return com.alibaba.graphscope.common.config.Utils.convertDotString(
typeStr.substring(leftBraceIdx + 1, rightBraceIdx));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.MapSqlType;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.NotImplementedException;
Expand Down Expand Up @@ -57,7 +55,7 @@ public CypherRecordParser(RelDataType outputType) {

@Override
public List<AnyValue> parseFrom(IrResult.Record record) {
logger.debug("record {}", record);
logger.info("record {}", record);
Preconditions.checkArgument(
record.getColumnsCount() == outputType.getFieldCount(),
"column size of results "
Expand Down Expand Up @@ -85,22 +83,22 @@ protected AnyValue parseEntry(IrResult.Entry entry, @Nullable RelDataType dataTy
switch (dataType.getSqlTypeName()) {
case MULTISET:
case ARRAY:
if (dataType instanceof ArraySqlType) {
return parseCollection(entry.getCollection(), dataType.getComponentType());
} else if (dataType instanceof ArbitraryArrayType) {
if (dataType instanceof ArbitraryArrayType) {
return parseCollection(
entry.getCollection(),
((ArbitraryArrayType) dataType).getComponentTypes());
} else {
return parseCollection(entry.getCollection(), dataType.getComponentType());
}
case MAP:
if (dataType instanceof MapSqlType) {
return parseKeyValues(
entry.getMap(), dataType.getKeyType(), dataType.getValueType());
} else if (dataType instanceof ArbitraryMapType) {
if (dataType instanceof ArbitraryMapType) {
return parseKeyValues(
entry.getMap(),
((ArbitraryMapType) dataType).getKeyTypes(),
((ArbitraryMapType) dataType).getValueTypes());
} else {
return parseKeyValues(
entry.getMap(), dataType.getKeyType(), dataType.getValueType());
}
default:
return parseElement(entry.getElement(), dataType);
Expand All @@ -122,6 +120,7 @@ protected AnyValue parseElement(IrResult.Element element, @Nullable RelDataType
}

protected AnyValue parseCollection(IrResult.Collection collection, RelDataType componentType) {
logger.info("collection {}", collection);
switch (componentType.getSqlTypeName()) {
case BOOLEAN:
Boolean[] boolObjs =
Expand Down

0 comments on commit a28c0e1

Please sign in to comment.