From fd42b56159b9c39ad817c95e135a07e58052dd59 Mon Sep 17 00:00:00 2001 From: Laurent Zeimes Date: Sun, 5 Nov 2023 12:13:54 +0100 Subject: [PATCH] feat: add support for tuple enum variants --- gen_testing.go | 2 +- generator.go | 32 ++++++++++----- idl.go | 104 +++++++++++++++++++------------------------------ 3 files changed, 65 insertions(+), 73 deletions(-) diff --git a/gen_testing.go b/gen_testing.go index b49c210..4765348 100644 --- a/gen_testing.go +++ b/gen_testing.go @@ -177,7 +177,7 @@ func genTestWithComplexEnum(tFunGroup *Group, insExportedName string, instructio enumName := arg.Type.GetIdlTypeDefined().Defined interfaceType := idl.Types.GetByName(enumName) - for _, variant := range interfaceType.Type.Variants { + for _, variant := range *interfaceType.Type.Variants { enumBlock.BlockFunc(func(variantBlock *Group) { variantBlock.Id("params").Op(":=").New(Id(insExportedName)) diff --git a/generator.go b/generator.go index df5c1a2..c4c5f47 100644 --- a/generator.go +++ b/generator.go @@ -1,6 +1,7 @@ package main import ( + "fmt" . "github.com/dave/jennifer/jen" "github.com/davecgh/go-spew/spew" bin "github.com/gagliardetto/binary" @@ -111,7 +112,7 @@ func genTypeName(idlTypeEnv IdlType) Code { case idlTypeEnv.IsArray(): { arr := idlTypeEnv.GetArray() - st.Index(Id(Itoa(arr.Num))).Add(genTypeName(arr.Thing)) + st.Index(Id(Itoa(arr.Num))).Add(genTypeName(arr.Elem)) } default: panic(spew.Sdump(idlTypeEnv)) @@ -252,7 +253,7 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code { if def.Type.Variants.IsSimpleEnum() { code.Type().Id(enumTypeName).Qual(PkgDfuseBinary, "BorshEnum") code.Line().Const().Parens(DoGroup(func(gr *Group) { - for variantIndex, variant := range def.Type.Variants { + for variantIndex, variant := range *def.Type.Variants { for docIndex, doc := range variant.Docs { if docIndex == 0 { @@ -277,7 +278,7 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code { Params(String()). BlockFunc(func(body *Group) { body.Switch(Id("value")).BlockFunc(func(switchBlock *Group) { - for _, variant := range def.Type.Variants { + for _, variant := range *def.Type.Variants { switchBlock.Case(Id(formatSimpleEnumVariantName(variant.Name, enumTypeName))).Line().Return(Lit(variant.Name)) } switchBlock.Default().Line().Return(Lit("")) @@ -302,13 +303,13 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code { "borsh_enum": "true", }) - for _, variant := range def.Type.Variants { + for _, variant := range *def.Type.Variants { structGroup.Id(ToCamel(variant.Name)).Id(formatComplexEnumVariantTypeName(enumTypeName, variant.Name)) } }, ).Line().Line() - for _, variant := range def.Type.Variants { + for _, variant := range *def.Type.Variants { // Name of the variant type if the enum is a complex enum (i.e. enum variants are inline structs): variantTypeNameComplex := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) @@ -333,8 +334,21 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code { }()) } default: - // TODO: handle tuples - panic("not handled: " + Sdump(variant.Fields)) + for i, variantTupleItem := range *variant.Fields.IdlEnumFieldsTuple { + variantField := IdlField{ + Name: fmt.Sprintf("Elem_%d", i), + Type: variantTupleItem, + } + structGroup.Add(genField(variantField, variantField.Type.IsIdlTypeOption())). + Add(func() Code { + if variantField.Type.IsIdlTypeOption() { + return Tag(map[string]string{ + "bin": "optional", + }) + } + return nil + }()) + } } }, ).Line().Line() @@ -485,7 +499,7 @@ func genMarshalWithEncoder_struct( BlockFunc(func(switchGroup *Group) { // TODO: maybe it's from idl.Accounts ??? interfaceType := idl.Types.GetByName(enumTypeName) - for variantIndex, variant := range interfaceType.Type.Variants { + for variantIndex, variant := range *interfaceType.Type.Variants { variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name) switchGroup.Case(Op("*").Id(variantTypeNameStruct)). @@ -629,7 +643,7 @@ func genUnmarshalWithDecoder_struct( argBody.Switch(Id("tmp").Dot("Enum")). BlockFunc(func(switchGroup *Group) { interfaceType := idl.Types.GetByName(enumName) - for variantIndex, variant := range interfaceType.Type.Variants { + for variantIndex, variant := range *interfaceType.Type.Variants { variantTypeNameComplex := formatComplexEnumVariantTypeName(enumName, variant.Name) if variant.IsUint8() { diff --git a/idl.go b/idl.go index 2f54e46..b2655d0 100644 --- a/idl.go +++ b/idl.go @@ -110,13 +110,13 @@ type IdlState struct { type IdlStateMethod = IdlInstruction -// type IdlAccountItem = IdlAccount | IdlAccounts; +// IdlAccountItem is of type IdlAccountItem = IdlAccount | IdlAccounts; type IdlAccountItem struct { IdlAccount *IdlAccount IdlAccounts *IdlAccounts } -func (item IdlAccountItem) Walk( +func (item *IdlAccountItem) Walk( parentGroupPath string, previousIndex *int, parentGroup *IdlAccounts, @@ -146,7 +146,7 @@ func (item IdlAccountItem) Walk( } // TODO: verify with examples -func (env *IdlAccountItem) UnmarshalJSON(data []byte) error { +func (item *IdlAccountItem) UnmarshalJSON(data []byte) error { var temp interface{} if err := json.Unmarshal(data, &temp); err != nil { @@ -154,7 +154,7 @@ func (env *IdlAccountItem) UnmarshalJSON(data []byte) error { } if temp == nil { - return fmt.Errorf("envelope is nil: %v", env) + return fmt.Errorf("envelope is nil: %v", item) } switch v := temp.(type) { @@ -169,14 +169,14 @@ func (env *IdlAccountItem) UnmarshalJSON(data []byte) error { // Multiple accounts: if _, ok := v["accounts"]; ok { - if err := TranscodeJSON(temp, &env.IdlAccounts); err != nil { + if err := TranscodeJSON(temp, &item.IdlAccounts); err != nil { return err } } // Single account: // TODO: check both isMut and isSigner if _, ok := v["isMut"]; ok { - if err := TranscodeJSON(temp, &env.IdlAccount); err != nil { + if err := TranscodeJSON(temp, &item.IdlAccount); err != nil { return err } } @@ -198,7 +198,7 @@ type IdlAccount struct { Optional bool `json:"optional"` // @custom } -// A nested/recursive version of IdlAccount. +// IdlAccounts is a nested/recursive version of IdlAccount. type IdlAccounts struct { Name string `json:"name"` Docs []string `json:"docs"` // @custom @@ -247,15 +247,15 @@ type IdlTypeOption struct { Option IdlType `json:"option"` } -// User defined type. +// IdlTypeDefined is a User defined type. type IdlTypeDefined struct { Defined string `json:"defined"` } -// Wrapper type: +// IdlTypeArray is a Wrapper type: type IdlTypeArray struct { - Thing IdlType - Num int + Elem IdlType + Num int } func (env *IdlType) UnmarshalJSON(data []byte) error { @@ -314,7 +314,7 @@ func (env *IdlType) UnmarshalJSON(data []byte) error { panic(Sf("array is not of expected length:\n%s", spew.Sdump(got))) } var target IdlTypeArray - if err := TranscodeJSON(arrVal[0], &target.Thing); err != nil { + if err := TranscodeJSON(arrVal[0], &target.Elem); err != nil { return err } @@ -331,7 +331,7 @@ func (env *IdlType) UnmarshalJSON(data []byte) error { return nil } -// Wrapper type: +// IdlType is a Wrapper type: type IdlType struct { asString IdlTypeAsString asIdlTypeVec *IdlTypeVec @@ -376,6 +376,13 @@ func (env *IdlType) GetArray() *IdlTypeArray { type IdlTypeDef struct { Name string `json:"name"` Type IdlTypeDefTy `json:"type"` + Docs []string `json:"docs"` +} + +type IdlTypeDefTy struct { + Kind IdlTypeDefTyKind `json:"kind"` + Fields *IdlStructFieldSlice `json:"fields,omitempty"` + Variants *IdlEnumVariantSlice `json:"variants,omitempty"` } type IdlTypeDefTyKind string @@ -385,26 +392,7 @@ const ( IdlTypeDefTyKindEnum IdlTypeDefTyKind = "enum" ) -// TODO: -type IdlTypeDefTyStruct struct { - Kind IdlTypeDefTyKind `json:"kind"` // == "struct" - - Fields *IdlTypeDefStruct `json:"fields,omitempty"` -} - -// TODO: -type IdlTypeDefTyEnum struct { - Kind IdlTypeDefTyKind `json:"kind"` // == "enum" - - Variants IdlEnumVariantSlice `json:"variants,omitempty"` -} - -type IdlTypeDefTy struct { - Kind IdlTypeDefTyKind `json:"kind"` - - Fields *IdlTypeDefStruct `json:"fields,omitempty"` - Variants IdlEnumVariantSlice `json:"variants,omitempty"` -} +type IdlStructFieldSlice []IdlField type IdlEnumVariantSlice []IdlEnumVariant @@ -447,46 +435,36 @@ type IdlEnumFieldsTuple []IdlType // TODO: verify with examples func (env *IdlEnumFields) UnmarshalJSON(data []byte) error { - - var temp interface{} - if err := json.Unmarshal(data, &temp); err != nil { + var tmp interface{} + if err := json.Unmarshal(data, &tmp); err != nil { return err } - if temp == nil { + if tmp == nil { return fmt.Errorf("envelope is nil: %v", env) } - switch v := temp.(type) { - case []interface{}: - { - // Ln(LimeBG("::IdlEnumFields")) - // spew.Dump(v) - - if len(v) == 0 { - return nil - } - - firstItem := v[0] - - if _, ok := firstItem.(map[string]interface{})["name"]; ok { - // TODO: - // If has `name` field, then it's most likely a IdlEnumFieldsNamed. - if err := TranscodeJSON(temp, &env.IdlEnumFieldsNamed); err != nil { - return err - } - } else { - if err := TranscodeJSON(temp, &env.IdlEnumFieldsTuple); err != nil { - return err - } - } + fields, ok := tmp.([]interface{}) + if !ok { + return fmt.Errorf("fields must be a slice") + } - // panic(Sf("what is this?:\n%s", spew.Sdump(temp))) + if len(fields) == 0 { + return nil + } + if m, ok := fields[0].(map[string]interface{}); ok && m["name"] != nil { + // If has `name` field, then it's most likely a IdlEnumFieldsNamed. + if err := TranscodeJSON(tmp, &env.IdlEnumFieldsNamed); err != nil { + return err + } + } else { + if err := TranscodeJSON(tmp, &env.IdlEnumFieldsTuple); err != nil { + return err } - default: - return fmt.Errorf("Unknown kind: %s", spew.Sdump(temp)) } + // panic(Sf("what is this?:\n%s", spew.Sdump(temp))) + return nil }