From cac222566cea4673975946d9e5f555accf9e30fa Mon Sep 17 00:00:00 2001 From: Ian Wahbe Date: Thu, 14 Dec 2023 16:34:23 -0800 Subject: [PATCH] Support .Elem() for custom types (#1571) This is necessary to support https://github.com/pulumi/pulumi-aws/issues/3110. --- pf/internal/schemashim/attr_schema.go | 34 ++++++---------------- pf/internal/schemashim/custom_type_test.go | 28 ++++++++++++++++++ 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/pf/internal/schemashim/attr_schema.go b/pf/internal/schemashim/attr_schema.go index 708be9655..9f697482e 100644 --- a/pf/internal/schemashim/attr_schema.go +++ b/pf/internal/schemashim/attr_schema.go @@ -15,11 +15,9 @@ package schemashim import ( - "fmt" bridge "github.com/pulumi/pulumi-terraform-bridge/v3/pkg/tfbridge" pfattr "github.com/hashicorp/terraform-plugin-framework/attr" - "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-framework/types/basetypes" "github.com/pulumi/pulumi-terraform-bridge/pf/internal/pfutils" @@ -86,39 +84,25 @@ func (*attrSchema) StateFunc() shim.SchemaStateFunc { // Needs to return a shim.Schema, a shim.Resource, or nil. func (s *attrSchema) Elem() interface{} { - t := s.attr.GetType() - + switch t := s.attr.GetType().(type) { // The ObjectType can be triggered through tfsdk.SingleNestedAttributes. Logically it defines an attribute with // a type that is an Object type. To encode the schema of the Object type in a way the shim layer understands, // Elem() needes to return a Resource value. // // See also: documentation on shim.Schema.Elem(). - if tt, ok := t.(basetypes.ObjectTypable); ok { - var res shim.Resource = newObjectPseudoResource(tt, s.attr.Nested(), nil) + case basetypes.ObjectTypable: + var res shim.Resource = newObjectPseudoResource(t, s.attr.Nested(), nil) return res - } - if tt, ok := t.(pfattr.TypeWithElementTypes); ok { - var res shim.Resource = newTuplePseudoResource(tt) + case pfattr.TypeWithElementTypes: + var res shim.Resource = newTuplePseudoResource(t) return res - } + case pfattr.TypeWithElementType: + return shim.Schema(newTypeSchema(t.ElementType(), s.attr.Nested())) - // Anything else that does not have an ElementType can be skipped. - if _, ok := t.(pfattr.TypeWithElementType); !ok { - return nil - } - - var schema shim.Schema - switch tt := t.(type) { - case types.MapType: - schema = newTypeSchema(tt.ElemType, s.attr.Nested()) - case types.ListType: - schema = newTypeSchema(tt.ElemType, s.attr.Nested()) - case types.SetType: - schema = newTypeSchema(tt.ElemType, s.attr.Nested()) + // t does not support any kind of element type. default: - panic(fmt.Errorf("This Elem() case is not yet supported: %v", t)) + return nil } - return schema } func (*attrSchema) MaxItems() int { diff --git a/pf/internal/schemashim/custom_type_test.go b/pf/internal/schemashim/custom_type_test.go index 8abbbae3a..7b6ea2ef6 100644 --- a/pf/internal/schemashim/custom_type_test.go +++ b/pf/internal/schemashim/custom_type_test.go @@ -89,6 +89,34 @@ func TestCustomListType(t *testing.T) { assert.Equal(t, shim.TypeString, create.Type()) } +func TestCustomListAttribute(t *testing.T) { + ctx := context.Background() + + raw := schema.ListNestedAttribute{ + CustomType: newListNestedObjectTypeOf[searchFilterModel](ctx, types.ObjectType{ + AttrTypes: map[string]attr.Type{ + "filter_string": basetypes.StringType{}, + }, + }), + NestedObject: schema.NestedAttributeObject{ + Attributes: map[string]schema.Attribute{ + "filter_string": schema.StringAttribute{ + Required: true, + }, + }, + }, + } + + shimmed := &attrSchema{"key", pfutils.FromAttrLike(raw)} + assert.Equal(t, shim.TypeList, shimmed.Type()) + assert.NotNil(t, shimmed.Elem()) + _, isPseudoResource := shimmed.Elem().(shim.Schema) + assert.Truef(t, isPseudoResource, "expected shim.Elem() to be of type shim.Resource, encoding an object type") + + create := shimmed.Elem().(shim.Schema).Elem().(shim.Resource).Schema().Get("filter_string") + assert.Equal(t, shim.TypeString, create.Type()) +} + func TestCustomSetType(t *testing.T) { ctx := context.Background()