diff --git a/cty/function/stdlib/collection.go b/cty/function/stdlib/collection.go index 1816bb9c..12b2914f 100644 --- a/cty/function/stdlib/collection.go +++ b/cty/function/stdlib/collection.go @@ -397,9 +397,15 @@ var DistinctFunc = function.New(&function.Spec{ } var list []cty.Value + buckets := make(map[int][]cty.Value) for it := listVal.ElementIterator(); it.Next(); { _, v := it.Element() - list, err = appendIfMissing(list, v) + h := v.Hash() + if _, ok := buckets[h]; !ok { + buckets[h] = make([]cty.Value, 0) + } + + list, buckets[h], err = appendIfMissing(list, buckets[h], v) if err != nil { return cty.NilVal, err } @@ -1429,18 +1435,18 @@ var ZipmapFunc = function.New(&function.Spec{ }, }) -// helper function to add an element to a list, if it does not already exist -func appendIfMissing(slice []cty.Value, element cty.Value) ([]cty.Value, error) { - for _, ele := range slice { +// helper function to add an element to a list, if it does not already exist in a sublist +func appendIfMissing(slice []cty.Value, subslice []cty.Value, element cty.Value) ([]cty.Value, []cty.Value, error) { + for _, ele := range subslice { eq, err := Equal(ele, element) if err != nil { - return slice, err + return slice, subslice, err } if eq.True() { - return slice, nil + return slice, subslice, nil } } - return append(slice, element), nil + return append(slice, element), append(subslice, element), nil } // HasIndex determines whether the given collection can be indexed with the diff --git a/cty/function/stdlib/collection_test.go b/cty/function/stdlib/collection_test.go index b6b9b6b8..55ccc815 100644 --- a/cty/function/stdlib/collection_test.go +++ b/cty/function/stdlib/collection_test.go @@ -2800,3 +2800,43 @@ func TestSlice(t *testing.T) { }) } } + +func TestDistinct(t *testing.T) { + tests := []struct { + Key string + Input cty.Value + Want cty.Value + }{ + { + "first", + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + cty.StringVal("b"), + cty.StringVal("c"), + cty.StringVal("d"), + }), + cty.ListVal([]cty.Value{ + cty.StringVal("a"), + cty.StringVal("b"), + cty.StringVal("c"), + cty.StringVal("d"), + }), + }, + } + + for _, test := range tests { + t.Run(test.Key, func(t *testing.T) { + got, err := Distinct(test.Input) + + if err != nil { + t.Fatalf("unexpected error: %s", err) + } + + if !got.RawEquals(test.Want) { + t.Errorf("wrong result\ngot: %#v\nwant: %#v", got, test.Want) + } + }) + } +}