-
Notifications
You must be signed in to change notification settings - Fork 68
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #89 from mbarnes/enforce-visibility
Enforce visibility struct tags
- Loading branch information
Showing
5 changed files
with
1,295 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
package api | ||
|
||
// Copyright (c) Microsoft Corporation. | ||
// Licensed under the Apache License 2.0. | ||
|
||
import ( | ||
"fmt" | ||
"reflect" | ||
"strings" | ||
|
||
"github.com/Azure/ARO-HCP/internal/api/arm" | ||
) | ||
|
||
// Property visibility meanings: | ||
// https://azure.github.io/typespec-azure/docs/howtos/ARM/resource-type#property-visibility-and-other-constraints | ||
// | ||
// Field mutability guidelines: | ||
// https://github.com/microsoft/api-guidelines/blob/vNext/azure/Guidelines.md#resource-schema--field-mutability | ||
|
||
const VisibilityStructTagKey = "visibility" | ||
|
||
// VisibilityFlags holds a visibility struct tag value as bit flags. | ||
type VisibilityFlags uint8 | ||
|
||
const ( | ||
VisibilityRead VisibilityFlags = 1 << iota | ||
VisibilityCreate | ||
VisibilityUpdate | ||
|
||
// option flags | ||
VisibilityCaseInsensitive | ||
|
||
VisibilityDefault = VisibilityRead | VisibilityCreate | VisibilityUpdate | ||
) | ||
|
||
func (f VisibilityFlags) ReadOnly() bool { | ||
return f&(VisibilityRead|VisibilityCreate|VisibilityUpdate) == VisibilityRead | ||
} | ||
|
||
func (f VisibilityFlags) CanUpdate() bool { | ||
return f&VisibilityUpdate != 0 | ||
} | ||
|
||
func (f VisibilityFlags) CaseInsensitive() bool { | ||
return f&VisibilityCaseInsensitive != 0 | ||
} | ||
|
||
func (f VisibilityFlags) String() string { | ||
s := []string{} | ||
if f&VisibilityRead != 0 { | ||
s = append(s, "read") | ||
} | ||
if f&VisibilityCreate != 0 { | ||
s = append(s, "create") | ||
} | ||
if f&VisibilityUpdate != 0 { | ||
s = append(s, "update") | ||
} | ||
if f&VisibilityCaseInsensitive != 0 { | ||
s = append(s, "nocase") | ||
} | ||
return strings.Join(s, " ") | ||
} | ||
|
||
func GetVisibilityFlags(tag reflect.StructTag) (VisibilityFlags, bool) { | ||
var flags VisibilityFlags | ||
|
||
tagValue, ok := tag.Lookup(VisibilityStructTagKey) | ||
if ok { | ||
for _, v := range strings.Fields(tagValue) { | ||
switch strings.ToLower(v) { | ||
case "read": | ||
flags |= VisibilityRead | ||
case "create": | ||
flags |= VisibilityCreate | ||
case "update": | ||
flags |= VisibilityUpdate | ||
case "nocase": | ||
flags |= VisibilityCaseInsensitive | ||
default: | ||
panic(fmt.Sprintf("Unknown visibility tag value '%s'", v)) | ||
} | ||
} | ||
} | ||
|
||
return flags, ok | ||
} | ||
|
||
func join(ns, name string) string { | ||
res := ns | ||
if res != "" { | ||
res += "." | ||
} | ||
res += name | ||
return res | ||
} | ||
|
||
type StructTagMap map[string]reflect.StructTag | ||
|
||
func buildStructTagMap(structTagMap StructTagMap, t reflect.Type, path string) { | ||
switch t.Kind() { | ||
case reflect.Pointer, reflect.Slice: | ||
buildStructTagMap(structTagMap, t.Elem(), path) | ||
|
||
case reflect.Struct: | ||
for i := 0; i < t.NumField(); i++ { | ||
field := t.Field(i) | ||
subpath := join(path, field.Name) | ||
|
||
if len(field.Tag) > 0 { | ||
structTagMap[subpath] = field.Tag | ||
} | ||
|
||
buildStructTagMap(structTagMap, field.Type, subpath) | ||
} | ||
} | ||
} | ||
|
||
// NewStructTagMap returns a mapping of dot-separated struct field names | ||
// to struct tags for the given type. Each versioned API should create | ||
// its own visibiilty map for tracked resource types. | ||
// | ||
// Note: This assumes field names for internal and versioned structs are | ||
// identical where visibility is explicitly specified. If some divergence | ||
// emerges, one workaround could be to pass a field name override map. | ||
func NewStructTagMap[T any]() StructTagMap { | ||
structTagMap := StructTagMap{} | ||
buildStructTagMap(structTagMap, reflect.TypeFor[T](), "") | ||
return structTagMap | ||
} | ||
|
||
type validateVisibility struct { | ||
structTagMap StructTagMap | ||
updating bool | ||
errs []arm.CloudErrorBody | ||
} | ||
|
||
// ValidateVisibility compares the new value (newVal) to the current value | ||
// (curVal) and returns any violations of visibility restrictions as defined | ||
// by structTagMap. | ||
func ValidateVisibility(newVal, curVal interface{}, structTagMap StructTagMap, updating bool) []arm.CloudErrorBody { | ||
vv := validateVisibility{ | ||
structTagMap: structTagMap, | ||
updating: updating, | ||
} | ||
vv.recurse(reflect.ValueOf(newVal), reflect.ValueOf(curVal), "", "", "", VisibilityDefault) | ||
return vv.errs | ||
} | ||
|
||
// mapKey is a lookup key for the StructTagMap. It DOES NOT include subscripts | ||
// for arrays, maps or slices since all elements are the same type. | ||
// | ||
// namespace is the struct field path up to but not including the field being | ||
// evaluated, analogous to path.Dir. It DOES include subscripts for arrays, | ||
// maps and slices since its purpose is for error reporting. | ||
// | ||
// fieldname is the current field being evaluated, analgous to path.Base. It | ||
// also includes subscripts for arrays, maps and slices when evaluating their | ||
// immediate elements. | ||
func (vv *validateVisibility) recurse(newVal, curVal reflect.Value, mapKey, namespace, fieldname string, implicitVisibility VisibilityFlags) { | ||
flags, ok := GetVisibilityFlags(vv.structTagMap[mapKey]) | ||
if !ok { | ||
flags = implicitVisibility | ||
} | ||
|
||
if newVal.Type() != curVal.Type() { | ||
panic(fmt.Sprintf("%s: value types differ (%s vs %s)", join(namespace, fieldname), newVal.Type().Name(), curVal.Type().Name())) | ||
} | ||
|
||
// Generated API structs are all pointer fields. A nil pointer in | ||
// the incoming request (newVal) means the value is absent, which | ||
// is always acceptable for visibility validation. | ||
switch newVal.Kind() { | ||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Pointer, reflect.Slice: | ||
if newVal.IsNil() { | ||
return | ||
} | ||
} | ||
|
||
switch newVal.Kind() { | ||
case reflect.Bool: | ||
if newVal.Bool() != curVal.Bool() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
|
||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | ||
if newVal.Int() != curVal.Int() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
|
||
case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: | ||
if newVal.Uint() != curVal.Uint() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
|
||
case reflect.Float32, reflect.Float64: | ||
if newVal.Float() != curVal.Float() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
|
||
case reflect.Complex64, reflect.Complex128: | ||
if newVal.Complex() != curVal.Complex() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
|
||
case reflect.String: | ||
if flags.CaseInsensitive() { | ||
if !strings.EqualFold(newVal.String(), curVal.String()) { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
} else { | ||
if newVal.String() != curVal.String() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} | ||
} | ||
|
||
case reflect.Slice: | ||
// We already know that newVal is not nil. | ||
if curVal.IsNil() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
return | ||
} | ||
|
||
fallthrough | ||
|
||
case reflect.Array: | ||
if newVal.Len() != curVal.Len() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} else { | ||
for i := 0; i < min(newVal.Len(), curVal.Len()); i++ { | ||
subscript := fmt.Sprintf("[%d]", i) | ||
vv.recurse(newVal.Index(i), curVal.Index(i), mapKey, namespace, fieldname+subscript, flags) | ||
} | ||
} | ||
|
||
case reflect.Interface, reflect.Pointer: | ||
// We already know that newVal is not nil. | ||
if curVal.IsNil() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} else { | ||
vv.recurse(newVal.Elem(), curVal.Elem(), mapKey, namespace, fieldname, flags) | ||
} | ||
|
||
case reflect.Map: | ||
// We already know that newVal is not nil. | ||
if curVal.IsNil() || newVal.Len() != curVal.Len() { | ||
vv.checkFlags(flags, namespace, fieldname) | ||
} else { | ||
iter := newVal.MapRange() | ||
for iter.Next() { | ||
k := iter.Key() | ||
|
||
subscript := fmt.Sprintf("[%q]", k.Interface()) | ||
if curVal.MapIndex(k).IsValid() { | ||
vv.recurse(newVal.MapIndex(k), curVal.MapIndex(k), mapKey, namespace, fieldname+subscript, flags) | ||
} else { | ||
vv.checkFlags(flags, namespace, fieldname+subscript) | ||
} | ||
} | ||
} | ||
|
||
case reflect.Struct: | ||
for i := 0; i < newVal.NumField(); i++ { | ||
structField := newVal.Type().Field(i) | ||
mapKeyNext := join(mapKey, structField.Name) | ||
namespaceNext := join(namespace, fieldname) | ||
fieldnameNext := GetJSONTagName(vv.structTagMap[mapKeyNext]) | ||
if fieldnameNext == "" { | ||
fieldnameNext = structField.Name | ||
} | ||
vv.recurse(newVal.Field(i), curVal.Field(i), mapKeyNext, namespaceNext, fieldnameNext, flags) | ||
} | ||
} | ||
} | ||
|
||
func (vv *validateVisibility) checkFlags(flags VisibilityFlags, namespace, fieldname string) { | ||
if flags.ReadOnly() { | ||
vv.errs = append(vv.errs, | ||
arm.CloudErrorBody{ | ||
Code: arm.CloudErrorCodeInvalidRequestContent, | ||
Message: fmt.Sprintf("Field '%s' is read-only", fieldname), | ||
Target: join(namespace, fieldname), | ||
}) | ||
} else if vv.updating && !flags.CanUpdate() { | ||
vv.errs = append(vv.errs, | ||
arm.CloudErrorBody{ | ||
Code: arm.CloudErrorCodeInvalidRequestContent, | ||
Message: fmt.Sprintf("Field '%s' cannot be updated", fieldname), | ||
Target: join(namespace, fieldname), | ||
}) | ||
} | ||
} |
Oops, something went wrong.