diff --git a/auth/oauth/oauth.go b/auth/oauth/oauth.go index e5dd0c4..9a68c35 100644 --- a/auth/oauth/oauth.go +++ b/auth/oauth/oauth.go @@ -46,11 +46,14 @@ func (a *OAuth) Login() http.HandlerFunc { } // State contain the provider and the csrf token. state := fmt.Sprintf("%s,%s", token, p) + authCodeURL := provider.AuthCodeURL(state) + fmt.Println(authCodeURL) + http.SetCookie(w, cookie) http.Redirect( w, r, - provider.AuthCodeURL(state), + authCodeURL, http.StatusFound, ) } @@ -82,6 +85,12 @@ func (a *OAuth) CallBack() http.HandlerFunc { return } + errorDescription := val.Get("error_description") + if errorDescription != "" { + http.Error(w, errorDescription, http.StatusUnauthorized) + return + } + expectedCSRF, err := r.Cookie("csrf_token") if err == http.ErrNoCookie { http.Error(w, "no csrf cookie error", http.StatusUnauthorized) @@ -100,19 +109,19 @@ func (a *OAuth) CallBack() http.HandlerFunc { oauth2Token, err := provider.Exchange(r.Context(), code) if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) + http.Error(w, fmt.Sprintf("exchange: %v", err), http.StatusUnauthorized) return } userID, userName, err := provider.GetIdentity(r.Context(), oauth2Token) if err != nil { - http.Error(w, err.Error(), http.StatusUnauthorized) + http.Error(w, fmt.Sprintf("id: %v", err), http.StatusUnauthorized) return } token, err := a.JWTSecret.GenerateToken(userID, userName, strings.ToLower(p)) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, fmt.Sprintf("gen token: %v", err), http.StatusInternalServerError) return }