diff --git a/generator/golang/resolver.go b/generator/golang/resolver.go index 7d133d0f..76ab5a3a 100644 --- a/generator/golang/resolver.go +++ b/generator/golang/resolver.go @@ -157,20 +157,29 @@ func (r *Resolver) getContainerTypeName(g *Scope, t *parser.Type) (name string, // getIDValue returns the literal representation of a const value. // The extra must be associated with g and from a const value that has // type parser.ConstType_ConstIdentifier. -func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string, ok bool) { +func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string, t *parser.Type, ok bool) { if extra.Index == -1 { if extra.IsEnum { enum, ok := g.ast.GetEnum(extra.Sel) if !ok { - return "", false + return "", t, false } if en := g.Enum(enum.Name); en != nil { if ev := en.Value(extra.Name); ev != nil { v = ev.GoName().String() + t = &parser.Type{ + Name: enum.Name, + Category: parser.Category_Enum, + } } } } else { v = g.globals.Get(extra.Name) + con, ok := g.ast.GetConstant(extra.Name) + if !ok { + return "", t, false + } + t = con.Type } } else { g = g.includes[extra.Index].Scope @@ -186,7 +195,7 @@ func (r *Resolver) getIDValue(g *Scope, extra *parser.ConstValueExtra) (v string pkg := r.root.includeIDL(r.util, g.ast) v = pkg + "." + v } - return v, v != "" + return v, t, v != "" } // ResolveConst returns the initialization code for a constant or a default value. @@ -239,10 +248,14 @@ func (r *Resolver) onBool(g *Scope, name string, t *parser.Type, v *parser.Const return s, nil } - if val, ok := r.getIDValue(g, v.Extra); ok { - return val, nil + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", s) } - return "", fmt.Errorf("undefined value: %q", s) + if err := r.typeMatch(t, cate, name); err != nil { + return "", err + } + return val, nil } return "", errTypeMissMatch(name, t, v) } @@ -260,12 +273,16 @@ func (r *Resolver) onInt(g *Scope, name string, t *parser.Type, v *parser.ConstV if s == "false" { return "0", nil } - if val, ok := r.getIDValue(g, v.Extra); ok { - goType, _ := r.getTypeName(g, t) - val = fmt.Sprintf("%s(%s)", goType, val) - return val, nil + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", s) } - return "", fmt.Errorf("undefined value: %q", s) + if err := r.typeMatch(t, cate, name); err != nil { + return "", err + } + goType, _ := r.getTypeName(g, t) + val = fmt.Sprintf("%s(%s)", goType, val) + return val, nil } return "", errTypeMissMatch(name, t, v) } @@ -286,10 +303,14 @@ func (r *Resolver) onDouble(g *Scope, name string, t *parser.Type, v *parser.Con if s == "false" { return "0.0", nil } - if val, ok := r.getIDValue(g, v.Extra); ok { - return val, nil + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", s) + } + if err := r.typeMatch(t, cate, name); err != nil { + return "", err } - return "", fmt.Errorf("undefined value: %q", s) + return val, nil } return "", errTypeMissMatch(name, t, v) } @@ -310,10 +331,14 @@ func (r *Resolver) onStrBin(g *Scope, name string, t *parser.Type, v *parser.Con break } - if val, ok := r.getIDValue(g, v.Extra); ok { - return val, nil + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", s) } - return "", fmt.Errorf("undefined value: %q", s) + if err := r.typeMatch(t, cate, name); err != nil { + return "", err + } + return val, nil default: } return "", errTypeMissMatch(name, t, v) @@ -324,10 +349,14 @@ func (r *Resolver) onEnum(g *Scope, name string, t *parser.Type, v *parser.Const case parser.ConstType_ConstInt: return fmt.Sprintf("%d", v.TypedValue.GetInt()), nil case parser.ConstType_ConstIdentifier: - val, ok := r.getIDValue(g, v.Extra) - if ok { - return val, nil + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier()) + } + if err := r.typeMatch(t, cate, name); err != nil { + return "", err } + return val, nil } return "", fmt.Errorf("expect const value for %q is a int or enum, got %+v", name, v) } @@ -354,8 +383,14 @@ func (r *Resolver) onSetOrList(g *Scope, name string, t *parser.Type, v *parser. return fmt.Sprintf("%s{\n%s\n}", goType, strings.Join(ss, "\n")), nil case parser.ConstType_ConstIdentifier: - val, ok := r.getIDValue(g, v.Extra) - if ok && val != "true" && val != "false" { + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier()) + } + if err := r.typeMatch(t, cate, name); err != nil { + return "", err + } + if val != "true" && val != "false" { return val, nil } @@ -391,8 +426,14 @@ func (r *Resolver) onMap(g *Scope, name string, t *parser.Type, v *parser.ConstV return fmt.Sprintf("%s{\n%s\n}", goType, strings.Join(kvs, "\n")), nil case parser.ConstType_ConstIdentifier: - val, ok := r.getIDValue(g, v.Extra) - if ok && val != "true" && val != "false" { + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier()) + } + if err := r.typeMatch(t, cate, name); err != nil { + return "", err + } + if val != "true" && val != "false" { return val, nil } } @@ -406,8 +447,14 @@ func (r *Resolver) onStructLike(g *Scope, name string, t *parser.Type, v *parser return "", err } if v.Type == parser.ConstType_ConstIdentifier { - val, ok := r.getIDValue(g, v.Extra) - if ok && val != "true" && val != "false" { + val, cate, ok := r.getIDValue(g, v.Extra) + if !ok { + return "", fmt.Errorf("undefined value: %q", v.TypedValue.GetIdentifier()) + } + if err := r.typeMatch(t, cate, name); err != nil { + return "", err + } + if val != "true" && val != "false" { return val, nil } } @@ -450,7 +497,7 @@ func (r *Resolver) onStructLike(g *Scope, name string, t *parser.Type, v *parser } if NeedRedirect(f) { - if f.Type.Category.IsBaseType() { + if IsBaseType(f.Type) { // a trick to create pointers without temporary variables val = fmt.Sprintf("(&struct{x %s}{%s}).x", typ, val) } @@ -493,6 +540,97 @@ func (r *Resolver) getStructLike(g *Scope, t *parser.Type) (f *Scope, s *parser. return } +func (r *Resolver) typeMatch(field *parser.Type, value *parser.Type, name string) error { + if field.Category.IsBool() { + if !value.Category.IsBool() { + return fmt.Errorf("type of %s is not bool type", name) + } + return nil + } + if field.Category.IsInteger() { + if !value.Category.IsDigital() { + return fmt.Errorf("type of %s is not digital type", name) + } + return nil + } + if field.Category.IsDouble() { + if !value.Category.IsDouble() { + return fmt.Errorf("type of %s is not double type", name) + } + return nil + } + if field.Category.IsString() { + if !value.Category.IsString() { + return fmt.Errorf("type of %s is not string type", name) + } + return nil + } + if field.Category.IsBinary() { + if !value.Category.IsString() && !value.Category.IsBinary() { + return fmt.Errorf("type of %s is not string or binary type", name) + } + return nil + } + if field.Category.IsEnum() { + if !value.Category.IsEnum() { + return fmt.Errorf("type of %s is not enum type", name) + } + if field.Name != value.Name { + return fmt.Errorf("enum type of %s is not %s", name, field.Name) + } + return nil + } + if field.Category.IsSet() { + if !value.Category.IsSet() { + return fmt.Errorf("type of %s is not set type", name) + } + return r.typeMatch(field.ValueType, value.ValueType, name) + } + if field.Category.IsList() { + if !value.Category.IsList() && !value.Category.IsSet() { + return fmt.Errorf("type of %s is not set or list type", name) + } + return r.typeMatch(field.ValueType, value.ValueType, name) + } + if field.Category.IsMap() { + if !value.Category.IsMap() { + return fmt.Errorf("type of %s is not map type", name) + } + if err := r.typeMatch(field.KeyType, value.KeyType, name); err != nil { + return err + } + return r.typeMatch(field.ValueType, value.ValueType, name) + } + if field.Category.IsStruct() { + if !value.Category.IsStruct() { + return fmt.Errorf("type of %s is not struct type", name) + } + if field.Name != value.Name { + return fmt.Errorf("type of %s is not %s", name, field.Name) + } + return nil + } + if field.Category.IsUnion() { + if !value.Category.IsUnion() { + return fmt.Errorf("type of %s is not union type", name) + } + if field.Name != value.Name { + return fmt.Errorf("type of %s is not %s", name, field.Name) + } + return nil + } + if field.Category.IsException() { + if !value.Category.IsException() { + return fmt.Errorf("type of %s is not exception type", name) + } + if field.Name != value.Name { + return fmt.Errorf("type of %s is not %s", name, field.Name) + } + return nil + } + return fmt.Errorf("type of %s not matched %s", name, field.Name) +} + func (r *Resolver) bin2str(t *parser.Type) *parser.Type { if t.Category == parser.Category_Binary { r := *t diff --git a/parser/AST-extend-category.go b/parser/AST-extend-category.go index 5f7c486c..16f3dbe2 100644 --- a/parser/AST-extend-category.go +++ b/parser/AST-extend-category.go @@ -118,3 +118,12 @@ func (p Category) IsContainerType() bool { func (p Category) IsStructLike() bool { return p == Category_Struct || p == Category_Union || p == Category_Exception } + +func (p Category) IsInteger() bool { + return p == Category_Byte || p == Category_I16 || p == Category_I32 || p == Category_I64 +} + +func (p Category) IsDigital() bool { + return p == Category_Byte || p == Category_I16 || p == Category_I32 || + p == Category_I64 || p == Category_Double || p == Category_Enum +}