Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for tuple enum variants #18

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion gen_testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func genTestWithComplexEnum(tFunGroup *Group, insExportedName string, instructio

enumName := arg.Type.GetIdlTypeDefined().Defined
interfaceType := idl.Types.GetByName(enumName)
for _, variant := range interfaceType.Type.Variants {
for _, variant := range *interfaceType.Type.Variants {

enumBlock.BlockFunc(func(variantBlock *Group) {
variantBlock.Id("params").Op(":=").New(Id(insExportedName))
Expand Down
32 changes: 23 additions & 9 deletions generator.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package main

import (
"fmt"
. "github.com/dave/jennifer/jen"
"github.com/davecgh/go-spew/spew"
bin "github.com/gagliardetto/binary"
Expand Down Expand Up @@ -111,7 +112,7 @@ func genTypeName(idlTypeEnv IdlType) Code {
case idlTypeEnv.IsArray():
{
arr := idlTypeEnv.GetArray()
st.Index(Id(Itoa(arr.Num))).Add(genTypeName(arr.Thing))
st.Index(Id(Itoa(arr.Num))).Add(genTypeName(arr.Elem))
}
default:
panic(spew.Sdump(idlTypeEnv))
Expand Down Expand Up @@ -252,7 +253,7 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code {
if def.Type.Variants.IsSimpleEnum() {
code.Type().Id(enumTypeName).Qual(PkgDfuseBinary, "BorshEnum")
code.Line().Const().Parens(DoGroup(func(gr *Group) {
for variantIndex, variant := range def.Type.Variants {
for variantIndex, variant := range *def.Type.Variants {

for docIndex, doc := range variant.Docs {
if docIndex == 0 {
Expand All @@ -277,7 +278,7 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code {
Params(String()).
BlockFunc(func(body *Group) {
body.Switch(Id("value")).BlockFunc(func(switchBlock *Group) {
for _, variant := range def.Type.Variants {
for _, variant := range *def.Type.Variants {
switchBlock.Case(Id(formatSimpleEnumVariantName(variant.Name, enumTypeName))).Line().Return(Lit(variant.Name))
}
switchBlock.Default().Line().Return(Lit(""))
Expand All @@ -302,13 +303,13 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code {
"borsh_enum": "true",
})

for _, variant := range def.Type.Variants {
for _, variant := range *def.Type.Variants {
structGroup.Id(ToCamel(variant.Name)).Id(formatComplexEnumVariantTypeName(enumTypeName, variant.Name))
}
},
).Line().Line()

for _, variant := range def.Type.Variants {
for _, variant := range *def.Type.Variants {
// Name of the variant type if the enum is a complex enum (i.e. enum variants are inline structs):
variantTypeNameComplex := formatComplexEnumVariantTypeName(enumTypeName, variant.Name)

Expand All @@ -333,8 +334,21 @@ func genTypeDef(idl *IDL, withDiscriminator bool, def IdlTypeDef) Code {
}())
}
default:
// TODO: handle tuples
panic("not handled: " + Sdump(variant.Fields))
for i, variantTupleItem := range *variant.Fields.IdlEnumFieldsTuple {
variantField := IdlField{
Name: fmt.Sprintf("Elem_%d", i),
Type: variantTupleItem,
}
structGroup.Add(genField(variantField, variantField.Type.IsIdlTypeOption())).
Add(func() Code {
if variantField.Type.IsIdlTypeOption() {
return Tag(map[string]string{
"bin": "optional",
})
}
return nil
}())
}
}
},
).Line().Line()
Expand Down Expand Up @@ -485,7 +499,7 @@ func genMarshalWithEncoder_struct(
BlockFunc(func(switchGroup *Group) {
// TODO: maybe it's from idl.Accounts ???
interfaceType := idl.Types.GetByName(enumTypeName)
for variantIndex, variant := range interfaceType.Type.Variants {
for variantIndex, variant := range *interfaceType.Type.Variants {
variantTypeNameStruct := formatComplexEnumVariantTypeName(enumTypeName, variant.Name)

switchGroup.Case(Op("*").Id(variantTypeNameStruct)).
Expand Down Expand Up @@ -629,7 +643,7 @@ func genUnmarshalWithDecoder_struct(
argBody.Switch(Id("tmp").Dot("Enum")).
BlockFunc(func(switchGroup *Group) {
interfaceType := idl.Types.GetByName(enumName)
for variantIndex, variant := range interfaceType.Type.Variants {
for variantIndex, variant := range *interfaceType.Type.Variants {
variantTypeNameComplex := formatComplexEnumVariantTypeName(enumName, variant.Name)

if variant.IsUint8() {
Expand Down
104 changes: 41 additions & 63 deletions idl.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ type IdlState struct {

type IdlStateMethod = IdlInstruction

// type IdlAccountItem = IdlAccount | IdlAccounts;
// IdlAccountItem is of type IdlAccountItem = IdlAccount | IdlAccounts;
type IdlAccountItem struct {
IdlAccount *IdlAccount
IdlAccounts *IdlAccounts
}

func (item IdlAccountItem) Walk(
func (item *IdlAccountItem) Walk(
parentGroupPath string,
previousIndex *int,
parentGroup *IdlAccounts,
Expand Down Expand Up @@ -146,15 +146,15 @@ func (item IdlAccountItem) Walk(
}

// TODO: verify with examples
func (env *IdlAccountItem) UnmarshalJSON(data []byte) error {
func (item *IdlAccountItem) UnmarshalJSON(data []byte) error {

var temp interface{}
if err := json.Unmarshal(data, &temp); err != nil {
return err
}

if temp == nil {
return fmt.Errorf("envelope is nil: %v", env)
return fmt.Errorf("envelope is nil: %v", item)
}

switch v := temp.(type) {
Expand All @@ -169,14 +169,14 @@ func (env *IdlAccountItem) UnmarshalJSON(data []byte) error {

// Multiple accounts:
if _, ok := v["accounts"]; ok {
if err := TranscodeJSON(temp, &env.IdlAccounts); err != nil {
if err := TranscodeJSON(temp, &item.IdlAccounts); err != nil {
return err
}
}
// Single account:
// TODO: check both isMut and isSigner
if _, ok := v["isMut"]; ok {
if err := TranscodeJSON(temp, &env.IdlAccount); err != nil {
if err := TranscodeJSON(temp, &item.IdlAccount); err != nil {
return err
}
}
Expand All @@ -198,7 +198,7 @@ type IdlAccount struct {
Optional bool `json:"optional"` // @custom
}

// A nested/recursive version of IdlAccount.
// IdlAccounts is a nested/recursive version of IdlAccount.
type IdlAccounts struct {
Name string `json:"name"`
Docs []string `json:"docs"` // @custom
Expand Down Expand Up @@ -247,15 +247,15 @@ type IdlTypeOption struct {
Option IdlType `json:"option"`
}

// User defined type.
// IdlTypeDefined is a User defined type.
type IdlTypeDefined struct {
Defined string `json:"defined"`
}

// Wrapper type:
// IdlTypeArray is a Wrapper type:
type IdlTypeArray struct {
Thing IdlType
Num int
Elem IdlType
Num int
}

func (env *IdlType) UnmarshalJSON(data []byte) error {
Expand Down Expand Up @@ -314,7 +314,7 @@ func (env *IdlType) UnmarshalJSON(data []byte) error {
panic(Sf("array is not of expected length:\n%s", spew.Sdump(got)))
}
var target IdlTypeArray
if err := TranscodeJSON(arrVal[0], &target.Thing); err != nil {
if err := TranscodeJSON(arrVal[0], &target.Elem); err != nil {
return err
}

Expand All @@ -331,7 +331,7 @@ func (env *IdlType) UnmarshalJSON(data []byte) error {
return nil
}

// Wrapper type:
// IdlType is a Wrapper type:
type IdlType struct {
asString IdlTypeAsString
asIdlTypeVec *IdlTypeVec
Expand Down Expand Up @@ -376,6 +376,13 @@ func (env *IdlType) GetArray() *IdlTypeArray {
type IdlTypeDef struct {
Name string `json:"name"`
Type IdlTypeDefTy `json:"type"`
Docs []string `json:"docs"`
}

type IdlTypeDefTy struct {
Kind IdlTypeDefTyKind `json:"kind"`
Fields *IdlStructFieldSlice `json:"fields,omitempty"`
Variants *IdlEnumVariantSlice `json:"variants,omitempty"`
}

type IdlTypeDefTyKind string
Expand All @@ -385,26 +392,7 @@ const (
IdlTypeDefTyKindEnum IdlTypeDefTyKind = "enum"
)

// TODO:
type IdlTypeDefTyStruct struct {
Kind IdlTypeDefTyKind `json:"kind"` // == "struct"

Fields *IdlTypeDefStruct `json:"fields,omitempty"`
}

// TODO:
type IdlTypeDefTyEnum struct {
Kind IdlTypeDefTyKind `json:"kind"` // == "enum"

Variants IdlEnumVariantSlice `json:"variants,omitempty"`
}

type IdlTypeDefTy struct {
Kind IdlTypeDefTyKind `json:"kind"`

Fields *IdlTypeDefStruct `json:"fields,omitempty"`
Variants IdlEnumVariantSlice `json:"variants,omitempty"`
}
type IdlStructFieldSlice []IdlField

type IdlEnumVariantSlice []IdlEnumVariant

Expand Down Expand Up @@ -447,46 +435,36 @@ type IdlEnumFieldsTuple []IdlType

// TODO: verify with examples
func (env *IdlEnumFields) UnmarshalJSON(data []byte) error {

var temp interface{}
if err := json.Unmarshal(data, &temp); err != nil {
var tmp interface{}
if err := json.Unmarshal(data, &tmp); err != nil {
return err
}

if temp == nil {
if tmp == nil {
return fmt.Errorf("envelope is nil: %v", env)
}

switch v := temp.(type) {
case []interface{}:
{
// Ln(LimeBG("::IdlEnumFields"))
// spew.Dump(v)

if len(v) == 0 {
return nil
}

firstItem := v[0]

if _, ok := firstItem.(map[string]interface{})["name"]; ok {
// TODO:
// If has `name` field, then it's most likely a IdlEnumFieldsNamed.
if err := TranscodeJSON(temp, &env.IdlEnumFieldsNamed); err != nil {
return err
}
} else {
if err := TranscodeJSON(temp, &env.IdlEnumFieldsTuple); err != nil {
return err
}
}
fields, ok := tmp.([]interface{})
if !ok {
return fmt.Errorf("fields must be a slice")
}

// panic(Sf("what is this?:\n%s", spew.Sdump(temp)))
if len(fields) == 0 {
return nil
}
if m, ok := fields[0].(map[string]interface{}); ok && m["name"] != nil {
// If has `name` field, then it's most likely a IdlEnumFieldsNamed.
if err := TranscodeJSON(tmp, &env.IdlEnumFieldsNamed); err != nil {
return err
}
} else {
if err := TranscodeJSON(tmp, &env.IdlEnumFieldsTuple); err != nil {
return err
}
default:
return fmt.Errorf("Unknown kind: %s", spew.Sdump(temp))
}

// panic(Sf("what is this?:\n%s", spew.Sdump(temp)))

return nil
}

Expand Down