From a9765ac6b784a28633c8f3f063ed6d9314918c1d Mon Sep 17 00:00:00 2001 From: jojo Date: Thu, 28 Sep 2023 22:25:52 -0300 Subject: [PATCH] :bug: chore: add tests for callback function --- api/authhandler.go | 14 ++++--------- api/authhandler_test.go | 46 +++++++++++++++++++++++++++++++++++++++++ api/token.go | 12 ++++++++++- 3 files changed, 61 insertions(+), 11 deletions(-) diff --git a/api/authhandler.go b/api/authhandler.go index a915276..4546e8e 100644 --- a/api/authhandler.go +++ b/api/authhandler.go @@ -3,7 +3,6 @@ package api import ( "encoding/json" - "errors" "net/http" "log/slog" @@ -13,8 +12,6 @@ import ( "golang.org/x/oauth2" ) -// Auth endpoints - // TODO(JOJO): randomize this var randState = "random" @@ -48,18 +45,15 @@ func login(w http.ResponseWriter, r *http.Request, a Auth) { func callback(w http.ResponseWriter, r *http.Request, a Auth) { state := r.FormValue("state") - if state == "" { - sendErr(w, http.StatusBadRequest, errors.New("missing state")) - return - } - if state != randState { - sendErr(w, http.StatusBadRequest, errors.New("invalid state")) + + if state == "" || state != randState { + sendErr(w, http.StatusBadRequest, Error{"invalid_state", "invalid state query parameter"}) return } code := r.FormValue("code") if code == "" { - sendErr(w, http.StatusBadRequest, errors.New("missing code")) + sendErr(w, http.StatusBadRequest, Error{"invalid_code", "invalid code query parameter"}) return } diff --git a/api/authhandler_test.go b/api/authhandler_test.go index da1985a..46bbf96 100644 --- a/api/authhandler_test.go +++ b/api/authhandler_test.go @@ -1,6 +1,7 @@ package api import ( + "context" "encoding/json" "net/http" "net/http/httptest" @@ -69,3 +70,48 @@ func TestLogin(t *testing.T) { assert(t, resp.Code, http.StatusFound) } + +type mockOAuthConfig struct{} + +func (m *mockOAuthConfig) AuthCodeURL(_ string, _ ...oauth2.AuthCodeOption) string { + return "" +} + +func (m *mockOAuthConfig) Client(_ context.Context, _ *oauth2.Token) *http.Client { + return &http.Client{} +} + +func (m *mockOAuthConfig) Exchange(_ context.Context, _ string, _ ...oauth2.AuthCodeOption) (*oauth2.Token, error) { + return &oauth2.Token{}, nil +} + +func TestCallback(t *testing.T) { + auth := Auth{ + GoogleOAuthConfig: &mockOAuthConfig{}, + Domain: "example.com", + } + r := chi.NewRouter() + RegisterAuthHandler(r, auth) + + req := httptest.NewRequest(http.MethodGet, "/callback?state="+randState+"&code=mock-code", nil) + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + assert(t, resp.Code, http.StatusSeeOther) + + cookies := resp.Result().Cookies() + assert(t, len(cookies), 1) + assert(t, cookies[0].Name, "jwt") +} + +func TestCallback_InvalidParam(t *testing.T) { + r := chi.NewRouter() + RegisterAuthHandler(r, Auth{}) + + req := httptest.NewRequest(http.MethodGet, "/callback", nil) + + resp := httptest.NewRecorder() + r.ServeHTTP(resp, req) + + assert(t, resp.Code, http.StatusBadRequest) +} diff --git a/api/token.go b/api/token.go index ea4826e..9cd6424 100644 --- a/api/token.go +++ b/api/token.go @@ -1,11 +1,21 @@ package api import ( + "context" + "net/http" + "github.com/go-chi/jwtauth/v5" "github.com/golang-jwt/jwt" "golang.org/x/oauth2" ) +// OAuth2Interface is an interface for the oauth2.Config struct to be able to mock it and test the callback behaviour. +type OAuth2Interface interface { + Exchange(ctx context.Context, code string, opts ...oauth2.AuthCodeOption) (*oauth2.Token, error) + Client(ctx context.Context, t *oauth2.Token) *http.Client + AuthCodeURL(state string, opts ...oauth2.AuthCodeOption) string +} + // Auth have the configuration for Auth endpoints. type Auth struct { // Google OAuth2 @@ -18,7 +28,7 @@ type Auth struct { // JWT secret key JWTSecretKey string // Google OAuth2 config struct - GoogleOAuthConfig *oauth2.Config + GoogleOAuthConfig OAuth2Interface // Access type AccessType string // offline(for local) or online(for production) }