Skip to content

Commit

Permalink
feat: check default value type
Browse files Browse the repository at this point in the history
  • Loading branch information
jayantxie committed Aug 11, 2022
1 parent e1caf7d commit 3790b20
Show file tree
Hide file tree
Showing 3 changed files with 181 additions and 27 deletions.
192 changes: 165 additions & 27 deletions generator/golang/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,20 +157,29 @@ func (r *Resolver) getContainerTypeName(g *Scope, t *parser.Type) (name string,
// getIDValue returns the literal representation of a const value.
// The extra must be associated with g and from a const value that has
// type parser.ConstType_ConstIdentifier.
func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string, ok bool) {
func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string, t *parser.Type, ok bool) {
if extra.Index == -1 {
if extra.IsEnum {
enum, ok := g.ast.GetEnum(extra.Sel)
if !ok {
return "", false
return "", t, false
}
if en := g.Enum(enum.Name); en != nil {
if ev := en.Value(extra.Name); ev != nil {
v = ev.GoName().String()
t = &parser.Type{
Name: enum.Name,
Category: parser.Category_Enum,
}
}
}
} else {
v = g.globals.Get(extra.Name)
con, ok := g.ast.GetConstant(extra.Name)
if !ok {
return "", t, false
}
t = con.Type
}
} else {
g = g.includes[extra.Index].Scope
Expand All @@ -186,7 +195,7 @@ func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string
pkg := r.root.includeIDL(r.util, g.ast)
v = pkg + "." + v
}
return v, v != ""
return v, t, v != ""
}

// ResolveConst returns the initialization code for a constant or a default value.
Expand Down Expand Up @@ -239,10 +248,14 @@ func (r *Resolver) onBool(g *Scope, name string, t *parser.Type, v *parser.Const
return s, nil
}

if val, ok := r.getIDValue(g, v.Extra); ok {
return val, nil
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", s)
}
return "", fmt.Errorf("undefined value: %q", s)
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
return val, nil
}
return "", errTypeMissMatch(name, t, v)
}
Expand All @@ -260,12 +273,16 @@ func (r *Resolver) onInt(g *Scope, name string, t *parser.Type, v *parser.ConstV
if s == "false" {
return "0", nil
}
if val, ok := r.getIDValue(g, v.Extra); ok {
goType, _ := r.getTypeName(g, t)
val = fmt.Sprintf("%s(%s)", goType, val)
return val, nil
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", s)
}
return "", fmt.Errorf("undefined value: %q", s)
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
goType, _ := r.getTypeName(g, t)
val = fmt.Sprintf("%s(%s)", goType, val)
return val, nil
}
return "", errTypeMissMatch(name, t, v)
}
Expand All @@ -286,10 +303,14 @@ func (r *Resolver) onDouble(g *Scope, name string, t *parser.Type, v *parser.Con
if s == "false" {
return "0.0", nil
}
if val, ok := r.getIDValue(g, v.Extra); ok {
return val, nil
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", s)
}
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
return "", fmt.Errorf("undefined value: %q", s)
return val, nil
}
return "", errTypeMissMatch(name, t, v)
}
Expand All @@ -310,10 +331,14 @@ func (r *Resolver) onStrBin(g *Scope, name string, t *parser.Type, v *parser.Con
break
}

if val, ok := r.getIDValue(g, v.Extra); ok {
return val, nil
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", s)
}
return "", fmt.Errorf("undefined value: %q", s)
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
return val, nil
default:
}
return "", errTypeMissMatch(name, t, v)
Expand All @@ -324,10 +349,14 @@ func (r *Resolver) onEnum(g *Scope, name string, t *parser.Type, v *parser.Const
case parser.ConstType_ConstInt:
return fmt.Sprintf("%d", v.TypedValue.GetInt()), nil
case parser.ConstType_ConstIdentifier:
val, ok := r.getIDValue(g, v.Extra)
if ok {
return val, nil
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
}
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
return val, nil
}
return "", fmt.Errorf("expect const value for %q is a int or enum, got %+v", name, v)
}
Expand All @@ -354,8 +383,14 @@ func (r *Resolver) onSetOrList(g *Scope, name string, t *parser.Type, v *parser.
return fmt.Sprintf("%s{\n%s\n}", goType, strings.Join(ss, "\n")), nil

case parser.ConstType_ConstIdentifier:
val, ok := r.getIDValue(g, v.Extra)
if ok && val != "true" && val != "false" {
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
}
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
if val != "true" && val != "false" {
return val, nil
}

Expand Down Expand Up @@ -391,8 +426,14 @@ func (r *Resolver) onMap(g *Scope, name string, t *parser.Type, v *parser.ConstV
return fmt.Sprintf("%s{\n%s\n}", goType, strings.Join(kvs, "\n")), nil

case parser.ConstType_ConstIdentifier:
val, ok := r.getIDValue(g, v.Extra)
if ok && val != "true" && val != "false" {
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
}
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
if val != "true" && val != "false" {
return val, nil
}
}
Expand All @@ -406,8 +447,14 @@ func (r *Resolver) onStructLike(g *Scope, name string, t *parser.Type, v *parser
return "", err
}
if v.Type == parser.ConstType_ConstIdentifier {
val, ok := r.getIDValue(g, v.Extra)
if ok && val != "true" && val != "false" {
val, cate, ok := r.getIDValue(g, v.Extra)
if !ok {
return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier())
}
if err := r.typeMatch(t, cate, name); err != nil {
return "", err
}
if val != "true" && val != "false" {
return val, nil
}
}
Expand Down Expand Up @@ -450,7 +497,7 @@ func (r *Resolver) onStructLike(g *Scope, name string, t *parser.Type, v *parser
}

if NeedRedirect(f) {
if f.Type.Category.IsBaseType() {
if IsBaseType(f.Type) {
// a trick to create pointers without temporary variables
val = fmt.Sprintf("(&struct{x %s}{%s}).x", typ, val)
}
Expand Down Expand Up @@ -493,6 +540,97 @@ func (r *Resolver) getStructLike(g *Scope, t *parser.Type) (f *Scope, s *parser.
return
}

func (r *Resolver) typeMatch(field *parser.Type, value *parser.Type, name string) error {
if field.Category.IsBool() {
if !value.Category.IsBool() {
return fmt.Errorf("type of %s is not bool type", name)
}
return nil
}
if field.Category.IsInteger() {
if !value.Category.IsDigital() {
return fmt.Errorf("type of %s is not digital type", name)
}
return nil
}
if field.Category.IsDouble() {
if !value.Category.IsDouble() {
return fmt.Errorf("type of %s is not double type", name)
}
return nil
}
if field.Category.IsString() {
if !value.Category.IsString() {
return fmt.Errorf("type of %s is not string type", name)
}
return nil
}
if field.Category.IsBinary() {
if !value.Category.IsString() && !value.Category.IsBinary() {
return fmt.Errorf("type of %s is not string or binary type", name)
}
return nil
}
if field.Category.IsEnum() {
if !value.Category.IsEnum() {
return fmt.Errorf("type of %s is not enum type", name)
}
if field.NameWithReference() != value.NameWithReference() {
return fmt.Errorf("enum type of %s is not %s, %s", name, field.NameWithReference(), value.NameWithReference())
}
return nil
}
if field.Category.IsSet() {
if !value.Category.IsSet() {
return fmt.Errorf("type of %s is not set type", name)
}
return r.typeMatch(field.ValueType, value.ValueType, name)
}
if field.Category.IsList() {
if !value.Category.IsList() && !value.Category.IsSet() {
return fmt.Errorf("type of %s is not set or list type", name)
}
return r.typeMatch(field.ValueType, value.ValueType, name)
}
if field.Category.IsMap() {
if !value.Category.IsMap() {
return fmt.Errorf("type of %s is not map type", name)
}
if err := r.typeMatch(field.KeyType, value.KeyType, name); err != nil {
return err
}
return r.typeMatch(field.ValueType, value.ValueType, name)
}
if field.Category.IsStruct() {
if !value.Category.IsStruct() {
return fmt.Errorf("type of %s is not struct type", name)
}
if field.NameWithReference() != value.NameWithReference() {
return fmt.Errorf("type of %s is not %s", name, field.NameWithReference())
}
return nil
}
if field.Category.IsUnion() {
if !value.Category.IsUnion() {
return fmt.Errorf("type of %s is not union type", name)
}
if field.NameWithReference() != value.NameWithReference() {
return fmt.Errorf("type of %s is not %s", name, field.NameWithReference())
}
return nil
}
if field.Category.IsException() {
if !value.Category.IsException() {
return fmt.Errorf("type of %s is not exception type", name)
}
if field.NameWithReference() != value.NameWithReference() {
return fmt.Errorf("type of %s is not %s", name, field.NameWithReference())
}
return nil
}
return fmt.Errorf("type of %s not matched %s", name, field.NameWithReference())
}

func (r *Resolver) bin2str(t *parser.Type) *parser.Type {
if t.Category == parser.Category_Binary {
r := *t
Expand Down
9 changes: 9 additions & 0 deletions parser/AST-extend-category.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,12 @@ func (p Category) IsContainerType() bool {
func (p Category) IsStructLike() bool {
return p == Category_Struct || p == Category_Union || p == Category_Exception
}

func (p Category) IsInteger() bool {
return p == Category_Byte || p == Category_I16 || p == Category_I32 || p == Category_I64
}

func (p Category) IsDigital() bool {
return p == Category_Byte || p == Category_I16 || p == Category_I32 ||
p == Category_I64 || p == Category_Double || p == Category_Enum
}
7 changes: 7 additions & 0 deletions parser/AST-extend.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@ func (t *Type) String() string {
return t.Name
}

func (t *Type) NameWithReference() string {
if t.Reference != nil && t.Reference.Name != "" {
return t.Reference.Name
}
return t.Name
}

// GetField returns a field of the struct-like that matches the name.
func (s *StructLike) GetField(name string) (*Field, bool) {
for _, fi := range s.Fields {
Expand Down

0 comments on commit 3790b20

Please sign in to comment.