diff --git a/internal/auth.go b/internal/auth.go index ba43fadd..98892481 100644 --- a/internal/auth.go +++ b/internal/auth.go @@ -118,7 +118,6 @@ func ValidateDomains(user string, domains CommaSeparatedList) bool { return true } } - return false } @@ -197,23 +196,31 @@ func ClearCookie(r *http.Request) *http.Cookie { } } +func buildCSRFCookieName(nonce string) string { + return config.CSRFCookieName + "_" + nonce[:6] +} + // MakeCSRFCookie makes a csrf cookie (used during login only) +// +// Note, CSRF cookies live shorter than auth cookies, a fixed 1h. +// That's because some CSRF cookies may belong to auth flows that don't complete +// and thus may not get cleared by ClearCookie. func MakeCSRFCookie(r *http.Request, nonce string) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: buildCSRFCookieName(nonce), Value: nonce, Path: "/", Domain: csrfCookieDomain(r), HttpOnly: true, Secure: !config.InsecureCookie, - Expires: cookieExpiry(), + Expires: time.Now().Local().Add(time.Hour * 1), } } // ClearCSRFCookie makes an expired csrf cookie to clear csrf cookie -func ClearCSRFCookie(r *http.Request) *http.Cookie { +func ClearCSRFCookie(r *http.Request, c *http.Cookie) *http.Cookie { return &http.Cookie{ - Name: config.CSRFCookieName, + Name: c.Name, Value: "", Path: "/", Domain: csrfCookieDomain(r), @@ -223,18 +230,18 @@ func ClearCSRFCookie(r *http.Request) *http.Cookie { } } -// ValidateCSRFCookie validates the csrf cookie against state -func ValidateCSRFCookie(r *http.Request, c *http.Cookie) (valid bool, provider string, redirect string, err error) { - state := r.URL.Query().Get("state") +// FindCSRFCookie extracts the CSRF cookie from the request based on state. +func FindCSRFCookie(r *http.Request, state string) (c *http.Cookie, err error) { + // Check for CSRF cookie + return r.Cookie(buildCSRFCookieName(state)) +} +// ValidateCSRFCookie validates the csrf cookie against state +func ValidateCSRFCookie(c *http.Cookie, state string) (valid bool, provider string, redirect string, err error) { if len(c.Value) != 32 { return false, "", "", errors.New("Invalid CSRF cookie value") } - if len(state) < 34 { - return false, "", "", errors.New("Invalid CSRF state value") - } - // Check nonce match if c.Value != state[:32] { return false, "", "", errors.New("CSRF cookie does not match state") @@ -256,6 +263,14 @@ func MakeState(r *http.Request, p provider.Provider, nonce string) string { return fmt.Sprintf("%s:%s:%s", nonce, p.Name(), returnUrl(r)) } +// ValidateState checks whether the state is of right length. +func ValidateState(state string) error { + if len(state) < 34 { + return errors.New("Invalid CSRF state value") + } + return nil +} + // Nonce generates a random nonce func Nonce() (error, string) { nonce := make([]byte, 16) diff --git a/internal/auth_test.go b/internal/auth_test.go index 5b0bedaf..0f000603 100644 --- a/internal/auth_test.go +++ b/internal/auth_test.go @@ -61,29 +61,29 @@ func TestAuthValidateCookie(t *testing.T) { assert.Equal("test@test.com", email, "valid request should return user email") } -func TestAuthValidateEmail(t *testing.T) { +func TestAuthValidateUser(t *testing.T) { assert := assert.New(t) config, _ = NewConfig([]string{}) // Should allow any with no whitelist/domain is specified - v := ValidateEmail("test@test.com", "default") + v := ValidateUser("test@test.com", "default") assert.True(v, "should allow any domain if email domain is not defined") - v = ValidateEmail("one@two.com", "default") + v = ValidateUser("one@two.com", "default") assert.True(v, "should allow any domain if email domain is not defined") // Should allow matching domain config.Domains = []string{"test.com"} - v = ValidateEmail("one@two.com", "default") + v = ValidateUser("one@two.com", "default") assert.False(v, "should not allow user from another domain") - v = ValidateEmail("test@test.com", "default") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user from allowed domain") // Should allow matching whitelisted email address config.Domains = []string{} config.Whitelist = []string{"test@test.com"} - v = ValidateEmail("one@two.com", "default") + v = ValidateUser("one@two.com", "default") assert.False(v, "should not allow user not in whitelist") - v = ValidateEmail("test@test.com", "default") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user in whitelist") // Should allow only matching email address when @@ -91,11 +91,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("one@two.com", "default") + v = ValidateUser("one@two.com", "default") assert.False(v, "should not allow user not in either") - v = ValidateEmail("test@example.com", "default") + v = ValidateUser("test@example.com", "default") assert.False(v, "should not allow user from allowed domain") - v = ValidateEmail("test@test.com", "default") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user in whitelist") // Should allow either matching domain or email address when @@ -103,11 +103,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Domains = []string{"example.com"} config.Whitelist = []string{"test@test.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("one@two.com", "default") + v = ValidateUser("one@two.com", "default") assert.False(v, "should not allow user not in either") - v = ValidateEmail("test@example.com", "default") + v = ValidateUser("test@example.com", "default") assert.True(v, "should allow user from allowed domain") - v = ValidateEmail("test@test.com", "default") + v = ValidateUser("test@test.com", "default") assert.True(v, "should allow user in whitelist") // Rule testing @@ -117,11 +117,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Whitelist = []string{"test@test.com"} config.Rules = map[string]*Rule{"test": NewRule()} config.MatchWhitelistOrDomain = true - v = ValidateEmail("one@two.com", "test") + v = ValidateUser("one@two.com", "test") assert.False(v, "should not allow user not in either") - v = ValidateEmail("test@example.com", "test") + v = ValidateUser("test@example.com", "test") assert.True(v, "should allow user from allowed global domain") - v = ValidateEmail("test@test.com", "test") + v = ValidateUser("test@test.com", "test") assert.True(v, "should allow user in global whitelist") // Should allow matching domain in rule @@ -131,11 +131,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Rules = map[string]*Rule{"test": rule} rule.Domains = []string{"testrule.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("one@two.com", "test") + v = ValidateUser("one@two.com", "test") assert.False(v, "should not allow user from another domain") - v = ValidateEmail("one@testglobal.com", "test") + v = ValidateUser("one@testglobal.com", "test") assert.False(v, "should not allow user from global domain") - v = ValidateEmail("test@testrule.com", "test") + v = ValidateUser("test@testrule.com", "test") assert.True(v, "should allow user from allowed domain") // Should allow matching whitelist in rule @@ -145,11 +145,11 @@ func TestAuthValidateEmail(t *testing.T) { config.Rules = map[string]*Rule{"test": rule} rule.Whitelist = []string{"test@testrule.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("one@two.com", "test") + v = ValidateUser("one@two.com", "test") assert.False(v, "should not allow user from another domain") - v = ValidateEmail("test@testglobal.com", "test") + v = ValidateUser("test@testglobal.com", "test") assert.False(v, "should not allow user from global domain") - v = ValidateEmail("test@testrule.com", "test") + v = ValidateUser("test@testrule.com", "test") assert.True(v, "should allow user from allowed domain") // Should allow only matching email address when @@ -161,15 +161,15 @@ func TestAuthValidateEmail(t *testing.T) { rule.Domains = []string{"examplerule.com"} rule.Whitelist = []string{"test@testrule.com"} config.MatchWhitelistOrDomain = false - v = ValidateEmail("one@two.com", "test") + v = ValidateUser("one@two.com", "test") assert.False(v, "should not allow user not in either") - v = ValidateEmail("test@testglobal.com", "test") + v = ValidateUser("test@testglobal.com", "test") assert.False(v, "should not allow user in global whitelist") - v = ValidateEmail("test@exampleglobal.com", "test") + v = ValidateUser("test@exampleglobal.com", "test") assert.False(v, "should not allow user from global domain") - v = ValidateEmail("test@examplerule.com", "test") + v = ValidateUser("test@examplerule.com", "test") assert.False(v, "should not allow user from allowed domain") - v = ValidateEmail("test@testrule.com", "test") + v = ValidateUser("test@testrule.com", "test") assert.True(v, "should allow user in whitelist") // Should allow either matching domain or email address when @@ -181,15 +181,15 @@ func TestAuthValidateEmail(t *testing.T) { rule.Domains = []string{"examplerule.com"} rule.Whitelist = []string{"test@testrule.com"} config.MatchWhitelistOrDomain = true - v = ValidateEmail("one@two.com", "test") + v = ValidateUser("one@two.com", "test") assert.False(v, "should not allow user not in either") - v = ValidateEmail("test@testglobal.com", "test") + v = ValidateUser("test@testglobal.com", "test") assert.False(v, "should not allow user in global whitelist") - v = ValidateEmail("test@exampleglobal.com", "test") + v = ValidateUser("test@exampleglobal.com", "test") assert.False(v, "should not allow user from global domain") - v = ValidateEmail("test@examplerule.com", "test") + v = ValidateUser("test@examplerule.com", "test") assert.True(v, "should allow user from allowed domain") - v = ValidateEmail("test@testrule.com", "test") + v = ValidateUser("test@testrule.com", "test") assert.True(v, "should allow user in whitelist") } diff --git a/internal/server.go b/internal/server.go index a0bf2902..55288659 100644 --- a/internal/server.go +++ b/internal/server.go @@ -102,7 +102,7 @@ func (s *Server) AuthHandler(providerName, rule string) http.HandlerFunc { } // Validate user - valid := ValidateUser(email, rule) + valid := ValidateUser(user, rule) if !valid { logger.WithField("user", user).Warn("Invalid user") http.Error(w, fmt.Sprintf("User '%s' is not authorized", user), 401)