Skip to content

Commit

Permalink
fix conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
Max Mitchell committed Jan 12, 2021
1 parent 4906a18 commit 4091bb1
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 45 deletions.
39 changes: 27 additions & 12 deletions internal/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ func ValidateDomains(user string, domains CommaSeparatedList) bool {
return true
}
}

return false
}

Expand Down Expand Up @@ -197,23 +196,31 @@ func ClearCookie(r *http.Request) *http.Cookie {
}
}

func buildCSRFCookieName(nonce string) string {
return config.CSRFCookieName + "_" + nonce[:6]
}

// MakeCSRFCookie makes a csrf cookie (used during login only)
//
// Note, CSRF cookies live shorter than auth cookies, a fixed 1h.
// That's because some CSRF cookies may belong to auth flows that don't complete
// and thus may not get cleared by ClearCookie.
func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: buildCSRFCookieName(nonce),
Value: nonce,
Path: "/",
Domain: csrfCookieDomain(r),
HttpOnly: true,
Secure: !config.InsecureCookie,
Expires: cookieExpiry(),
Expires: time.Now().Local().Add(time.Hour * 1),
}
}

// ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie
func ClearCSRFCookie(r *http.Request) *http.Cookie {
func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie {
return &http.Cookie{
Name: config.CSRFCookieName,
Name: c.Name,
Value: "",
Path: "/",
Domain: csrfCookieDomain(r),
Expand All @@ -223,18 +230,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie {
}
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) {
state := r.URL.Query().Get("state")
// FindCSRFCookie extracts the CSRF cookie from the request based on state.
func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) {
// Check for CSRF cookie
return r.Cookie(buildCSRFCookieName(state))
}

// ValidateCSRFCookie validates the csrf cookie against state
func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) {
if len(c.Value) != 32 {
return false, "", "", errors.New("Invalid CSRF cookie value")
}

if len(state) < 34 {
return false, "", "", errors.New("Invalid CSRF state value")
}

// Check nonce match
if c.Value != state[:32] {
return false, "", "", errors.New("CSRF cookie does not match state")
Expand All @@ -256,6 +263,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string {
return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r))
}

// ValidateState checks whether the state is of right length.
func ValidateState(state string) error {
if len(state) < 34 {
return errors.New("Invalid CSRF state value")
}
return nil
}

// Nonce generates a random nonce
func Nonce() (error, string) {
nonce := make([]byte, 16)
Expand Down
64 changes: 32 additions & 32 deletions internal/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,53 +61,53 @@ func TestAuthValidateCookie(t *testing.T) {
assert.Equal("[email protected]", email, "valid request should return user email")
}

func TestAuthValidateEmail(t *testing.T) {
func TestAuthValidateUser(t *testing.T) {
assert := assert.New(t)
config, _ = NewConfig([]string{})

// Should allow any with no whitelist/domain is specified
v := ValidateEmail("[email protected]", "default")
v := ValidateUser("[email protected]", "default")
assert.True(v, "should allow any domain if email domain is not defined")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.True(v, "should allow any domain if email domain is not defined")

// Should allow matching domain
config.Domains = []string{"test.com"}
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.False(v, "should not allow user from another domain")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.True(v, "should allow user from allowed domain")

// Should allow matching whitelisted email address
config.Domains = []string{}
config.Whitelist = []string{"[email protected]"}
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.False(v, "should not allow user not in whitelist")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.True(v, "should allow user in whitelist")

// Should allow only matching email address when
// MatchWhitelistOrDomain is disabled
config.Domains = []string{"example.com"}
config.Whitelist = []string{"[email protected]"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.False(v, "should not allow user not in either")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.False(v, "should not allow user from allowed domain")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.True(v, "should allow user in whitelist")

// Should allow either matching domain or email address when
// MatchWhitelistOrDomain is enabled
config.Domains = []string{"example.com"}
config.Whitelist = []string{"[email protected]"}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.False(v, "should not allow user not in either")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.True(v, "should allow user from allowed domain")
v = ValidateEmail("[email protected]", "default")
v = ValidateUser("[email protected]", "default")
assert.True(v, "should allow user in whitelist")

// Rule testing
Expand All @@ -117,11 +117,11 @@ func TestAuthValidateEmail(t *testing.T) {
config.Whitelist = []string{"[email protected]"}
config.Rules = map[string]*Rule{"test": NewRule()}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user not in either")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user from allowed global domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user in global whitelist")

// Should allow matching domain in rule
Expand All @@ -131,11 +131,11 @@ func TestAuthValidateEmail(t *testing.T) {
config.Rules = map[string]*Rule{"test": rule}
rule.Domains = []string{"testrule.com"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from another domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from global domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user from allowed domain")

// Should allow matching whitelist in rule
Expand All @@ -145,11 +145,11 @@ func TestAuthValidateEmail(t *testing.T) {
config.Rules = map[string]*Rule{"test": rule}
rule.Whitelist = []string{"[email protected]"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from another domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from global domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user from allowed domain")

// Should allow only matching email address when
Expand All @@ -161,15 +161,15 @@ func TestAuthValidateEmail(t *testing.T) {
rule.Domains = []string{"examplerule.com"}
rule.Whitelist = []string{"[email protected]"}
config.MatchWhitelistOrDomain = false
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user not in either")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user in global whitelist")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from global domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from allowed domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user in whitelist")

// Should allow either matching domain or email address when
Expand All @@ -181,15 +181,15 @@ func TestAuthValidateEmail(t *testing.T) {
rule.Domains = []string{"examplerule.com"}
rule.Whitelist = []string{"[email protected]"}
config.MatchWhitelistOrDomain = true
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user not in either")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user in global whitelist")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.False(v, "should not allow user from global domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user from allowed domain")
v = ValidateEmail("[email protected]", "test")
v = ValidateUser("[email protected]", "test")
assert.True(v, "should allow user in whitelist")
}

Expand Down
2 changes: 1 addition & 1 deletion internal/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc {
}

// Validate user
valid := ValidateUser(email, rule)
valid := ValidateUser(user, rule)
if !valid {
logger.WithField("user", user).Warn("Invalid user")
http.Error(w, fmt.Sprintf("User '%s' is not authorized", user), 401)
Expand Down

0 comments on commit 4091bb1

Please sign in to comment.