diff --git a/bascule/basculehttp/enforcer.go b/bascule/basculehttp/enforcer.go index dd7005e..56c5d76 100644 --- a/bascule/basculehttp/enforcer.go +++ b/bascule/basculehttp/enforcer.go @@ -6,8 +6,19 @@ import ( "github.com/Comcast/comcast-bascule/bascule" ) +//go:generate stringer -type=NotFoundBehavior + +type NotFoundBehavior int + +// Behavior on not found +const ( + Forbid NotFoundBehavior = iota + Allow +) + type enforcer struct { - rules map[bascule.Authorization]bascule.Validators + notFoundBehavior NotFoundBehavior + rules map[bascule.Authorization]bascule.Validators } func (e *enforcer) decorate(next http.Handler) http.Handler { @@ -20,13 +31,22 @@ func (e *enforcer) decorate(next http.Handler) http.Handler { } rules, ok := e.rules[auth.Authorization] if !ok { - response.WriteHeader(http.StatusForbidden) - return - } - err := rules.Check(ctx, auth.Token) - if err != nil { - WriteResponse(response, http.StatusUnauthorized, err) - return + switch e.notFoundBehavior { + case Forbid: + response.WriteHeader(http.StatusForbidden) + return + case Allow: + // continue + default: + response.WriteHeader(http.StatusForbidden) + return + } + } else { + err := rules.Check(ctx, auth.Token) + if err != nil { + WriteResponse(response, http.StatusUnauthorized, err) + return + } } next.ServeHTTP(response, request) }) @@ -34,6 +54,12 @@ func (e *enforcer) decorate(next http.Handler) http.Handler { type EOption func(*enforcer) +func WithNotFoundBehavior(behavior NotFoundBehavior) EOption { + return func(e *enforcer) { + e.notFoundBehavior = behavior + } +} + func WithRules(key bascule.Authorization, v bascule.Validators) EOption { return func(e *enforcer) { e.rules[key] = v diff --git a/bascule/basculehttp/notfoundbehavior_string.go b/bascule/basculehttp/notfoundbehavior_string.go new file mode 100644 index 0000000..2a58acb --- /dev/null +++ b/bascule/basculehttp/notfoundbehavior_string.go @@ -0,0 +1,16 @@ +// Code generated by "stringer -type=NotFoundBehavior"; DO NOT EDIT. + +package basculehttp + +import "strconv" + +const _NotFoundBehavior_name = "ForbidAllow" + +var _NotFoundBehavior_index = [...]uint8{0, 6, 11} + +func (i NotFoundBehavior) String() string { + if i < 0 || i >= NotFoundBehavior(len(_NotFoundBehavior_index)-1) { + return "NotFoundBehavior(" + strconv.FormatInt(int64(i), 10) + ")" + } + return _NotFoundBehavior_name[_NotFoundBehavior_index[i]:_NotFoundBehavior_index[i+1]] +} diff --git a/bascule/checks.go b/bascule/checks.go index 1e0f141..a6aff39 100644 --- a/bascule/checks.go +++ b/bascule/checks.go @@ -12,6 +12,12 @@ const ( capabilitiesKey = "capabilities" ) +func CreateAllowAllCheck() ValidatorFunc { + return func(_ context.Context, _ Token) error { + return nil + } +} + func CreateValidTypeCheck(validTypes []string) ValidatorFunc { return func(_ context.Context, token Token) error { tt := token.Type() @@ -65,3 +71,19 @@ func CreateListAttributeCheck(key string, checks ...func(context.Context, []inte return errs } } + +func NonEmptyStringListCheck(ctx context.Context, vals []interface{}) error { + if len(vals) == 0 { + return errors.New("expected at least one value") + } + for _, val := range vals { + str, ok := val.(string) + if !ok { + return errors.New("expected value to be a string") + } + if len(str) == 0 { + return errors.New("expected string to be nonempty") + } + } + return nil +}