diff --git a/flytepropeller/pkg/compiler/validators/utils.go b/flytepropeller/pkg/compiler/validators/utils.go index cc7245a6e2..5f41a6e65e 100644 --- a/flytepropeller/pkg/compiler/validators/utils.go +++ b/flytepropeller/pkg/compiler/validators/utils.go @@ -162,12 +162,51 @@ func UnionDistinctVariableMaps(m1, m2 map[string]*core.Variable) (map[string]*co return res, nil } +func buildMultipleTypeUnion(innerType []*core.LiteralType) *core.LiteralType { + var variants []*core.LiteralType + isNested := false + + for _, x := range innerType { + unionType := x.GetCollectionType().GetUnionType() + if unionType != nil { + isNested = true + variants = append(variants, unionType.Variants...) + } else { + variants = append(variants, x) + } + } + unionLiteralType := &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: variants, + }, + }, + } + + if isNested { + return &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: unionLiteralType, + }, + } + } + + return unionLiteralType +} + func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { innerType := make([]*core.LiteralType, 0, 1) innerTypeSet := sets.NewString() + var noneType *core.LiteralType for _, x := range literals { otherType := LiteralTypeForLiteral(x) otherTypeKey := otherType.String() + if _, ok := x.GetValue().(*core.Literal_Collection); ok { + if x.GetCollection().GetLiterals() == nil { + noneType = otherType + continue + } + } if !innerTypeSet.Has(otherTypeKey) { innerType = append(innerType, otherType) @@ -175,6 +214,11 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { } } + // only add none type if there aren't other types + if len(innerType) == 0 && noneType != nil { + innerType = append(innerType, noneType) + } + if len(innerType) == 0 { return &core.LiteralType{ Type: &core.LiteralType_Simple{Simple: core.SimpleType_NONE}, @@ -195,14 +239,7 @@ func literalTypeForLiterals(literals []*core.Literal) *core.LiteralType { return 0 }) - - return &core.LiteralType{ - Type: &core.LiteralType_UnionType{ - UnionType: &core.UnionType{ - Variants: innerType, - }, - }, - } + return buildMultipleTypeUnion(innerType) } // LiteralTypeForLiteral gets LiteralType for literal, nil if the value of literal is unknown, or type collection/map of diff --git a/flytepropeller/pkg/compiler/validators/utils_test.go b/flytepropeller/pkg/compiler/validators/utils_test.go index a5e5d3f23c..29daed99ae 100644 --- a/flytepropeller/pkg/compiler/validators/utils_test.go +++ b/flytepropeller/pkg/compiler/validators/utils_test.go @@ -1,6 +1,7 @@ package validators import ( + "github.com/golang/protobuf/proto" "testing" "github.com/stretchr/testify/assert" @@ -51,6 +52,290 @@ func TestLiteralTypeForLiterals(t *testing.T) { assert.Equal(t, core.SimpleType_INTEGER.String(), lt.GetUnionType().Variants[0].GetSimple().String()) assert.Equal(t, core.SimpleType_STRING.String(), lt.GetUnionType().Variants[1].GetSimple().String()) }) + + t.Run("list with mixed types", func(t *testing.T) { + literals := &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_Integer{ + Integer: 1, + }, + }, + }, + }, + }, + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + }, + }, + { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Value: &core.Literal{ + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "foo", + }, + }, + }, + }, + }, + }, + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + lt := LiteralTypeForLiteral(literals) + + expectedLt := &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + }, + { + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + assert.True(t, proto.Equal(expectedLt, lt)) + }) + + t.Run("nested lists with empty list", func(t *testing.T) { + literals := &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{ + Primitive: &core.Primitive{ + Value: &core.Primitive_StringValue{ + StringValue: "foo", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{}, + }, + }, + }, + }, + }, + } + + lt := LiteralTypeForLiteral(literals) + + expectedLt := &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + } + + assert.True(t, proto.Equal(expectedLt, lt)) + }) + + t.Run("nested Lists with different types", func(t *testing.T) { + literals := &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_INTEGER}}, + Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_Integer{Integer: 1}}}}}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Scalar{ + Scalar: &core.Scalar{ + Value: &core.Scalar_Union{ + Union: &core.Union{ + Type: &core.LiteralType{Type: &core.LiteralType_Simple{Simple: core.SimpleType_STRING}}, + Value: &core.Literal{Value: &core.Literal_Scalar{Scalar: &core.Scalar{ + Value: &core.Scalar_Primitive{Primitive: &core.Primitive{Value: &core.Primitive_StringValue{StringValue: "foo"}}}}}}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + expectedLt := &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_UnionType{ + UnionType: &core.UnionType{ + Variants: []*core.LiteralType{ + { + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + { + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_STRING, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + lt := LiteralTypeForLiteral(literals) + + assert.True(t, proto.Equal(expectedLt, lt)) + }) + + t.Run("empty nested listed", func(t *testing.T) { + literals := &core.Literal{ + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{ + Literals: []*core.Literal{ + { + Value: &core.Literal_Collection{ + Collection: &core.LiteralCollection{}, + }, + }, + }, + }, + }, + } + + expectedLt := &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_CollectionType{ + CollectionType: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_NONE, + }, + }, + }, + }, + }, + } + + lt := LiteralTypeForLiteral(literals) + + assert.True(t, proto.Equal(expectedLt, lt)) + }) + } func TestJoinVariableMapsUniqueKeys(t *testing.T) {