Skip to content

Commit

Permalink
Add auth middleware and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
theandrew168 committed Sep 3, 2024
1 parent 3a4435c commit c4dd426
Show file tree
Hide file tree
Showing 5 changed files with 407 additions and 12 deletions.
8 changes: 7 additions & 1 deletion backend/web/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ func Handler(
) http.Handler {
mux := http.NewServeMux()

accountRequired := middleware.AccountRequired()
// adminRequired := middleware.Chain(accountRequired, middleware.AdminRequired())

// Host prometheus metrics on "/metrics".
mux.Handle("GET /metrics", promhttp.Handler())

Expand All @@ -42,13 +45,15 @@ func Handler(

// The main application routes start here.
mux.Handle("GET /{$}", page.HandleIndexPage(find))
mux.Handle("GET /blogs", page.HandleBlogsPage(find))

mux.Handle("GET /register", page.HandleRegisterPage())
mux.Handle("POST /register", page.HandleRegisterForm(repo))
mux.Handle("GET /signin", page.HandleSigninPage())
mux.Handle("POST /signin", page.HandleSigninForm(repo))
mux.Handle("POST /signout", page.HandleSignoutForm(repo))

mux.Handle("GET /blogs", accountRequired(page.HandleBlogsPage(find)))

// Requests that don't match any of the above handlers get a 404.
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(404)
Expand All @@ -60,5 +65,6 @@ func Handler(
middleware.RecoverPanic(),
middleware.SecureHeaders(),
middleware.LimitRequestBodySize(),
middleware.Authenticate(repo),
)
}
81 changes: 81 additions & 0 deletions backend/web/middleware/auth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package middleware

import (
"errors"
"net/http"

"github.com/theandrew168/bloggulus/backend/postgres"
"github.com/theandrew168/bloggulus/backend/repository"
"github.com/theandrew168/bloggulus/backend/web/util"
)

func Authenticate(repo *repository.Repository) Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

// Check for a sessionID cookie.
sessionID, err := r.Cookie(util.SessionCookieName)
if err != nil {
next.ServeHTTP(w, r)
return
}

// Lookup the account linked to the session.
account, err := repo.Account().ReadBySessionID(sessionID.Value)
if err != nil {
// If the user has an invalid / expired session cookie, delete it.
if errors.Is(err, postgres.ErrNotFound) {
cookie := util.NewExpiredCookie(util.SessionCookieName)
http.SetCookie(w, &cookie)

next.ServeHTTP(w, r)
return
}

http.Error(w, err.Error(), 500)
return
}

// If it exists, attach the account to the request context.
r = util.ContextSetAccount(r, account)

next.ServeHTTP(w, r)
})
}
}

func AccountRequired() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// If the request context has no account, then the user is not signed in (redirect).
_, ok := util.ContextGetAccount(r)
if !ok {
http.Redirect(w, r, "/signin", http.StatusSeeOther)
return
}

next.ServeHTTP(w, r)
})
}
}

func AdminRequired() Middleware {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// If the request context has no account, then the user is not signed in (redirect).
account, ok := util.ContextGetAccount(r)
if !ok {
http.Redirect(w, r, "/signin", http.StatusSeeOther)
return
}

// If the account exists but is not an admin account, show a 403 Forbidden page.
if !account.IsAdmin() {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}

next.ServeHTTP(w, r)
})
}
}
279 changes: 279 additions & 0 deletions backend/web/middleware/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,279 @@
package middleware_test

import (
"context"
"net/http"
"net/http/httptest"
"testing"

"github.com/theandrew168/bloggulus/backend/test"
"github.com/theandrew168/bloggulus/backend/web/middleware"
"github.com/theandrew168/bloggulus/backend/web/util"
)

func TestAuthenticate(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

account, _ := test.CreateAccount(t, repo)
_, sessionID := test.CreateSession(t, repo, account)
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID)

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, ok := util.ContextGetAccount(r)
test.AssertEqual(t, ok, true)
test.AssertEqual(t, got.ID(), account.ID())
})

h := middleware.Use(next,
middleware.Authenticate(repo),
)
h.ServeHTTP(w, r)
}

func TestAuthenticateNoSession(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next, middleware.Authenticate(repo))
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusOK)
}

func TestAuthenticateInvalidSession(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

sessionCookie := util.NewSessionCookie(util.SessionCookieName, "foobar")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next, middleware.Authenticate(repo))
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusOK)
}

func TestAccountRequired(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

account, _ := test.CreateAccount(t, repo)
_, sessionID := test.CreateSession(t, repo, account)
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID)

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, ok := util.ContextGetAccount(r)
test.AssertEqual(t, ok, true)
test.AssertEqual(t, got.ID(), account.ID())
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusOK)
}

func TestAccountRequiredNoSession(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther)
test.AssertEqual(t, rr.Header.Get("Location"), "/signin")
}

func TestAccountRequiredInvalidSession(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

sessionCookie := util.NewSessionCookie(util.SessionCookieName, "foobar")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther)
test.AssertEqual(t, rr.Header.Get("Location"), "/signin")
}

func TestAdminRequired(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

account, _ := test.CreateAccount(t, repo)
_, sessionID := test.CreateSession(t, repo, account)
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID)

// Make the account an admin via manual SQL.
err := repo.Exec(context.Background(), "UPDATE account SET is_admin = TRUE WHERE id = $1", account.ID())
test.AssertNilError(t, err)

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
got, ok := util.ContextGetAccount(r)
test.AssertEqual(t, ok, true)
test.AssertEqual(t, got.ID(), account.ID())
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
middleware.AdminRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusOK)
}

func TestAdminRequiredNoSession(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
middleware.AdminRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther)
test.AssertEqual(t, rr.Header.Get("Location"), "/signin")
}

func TestAdminRequiredInvalidSession(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

sessionCookie := util.NewSessionCookie(util.SessionCookieName, "foobar")

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
middleware.AdminRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther)
test.AssertEqual(t, rr.Header.Get("Location"), "/signin")
}

func TestAdminRequiredNotAdmin(t *testing.T) {
t.Parallel()

repo, closer := test.NewRepository(t)
defer closer()

account, _ := test.CreateAccount(t, repo)
_, sessionID := test.CreateSession(t, repo, account)
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID)

w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.AddCookie(&sessionCookie)

next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

h := middleware.Use(next,
middleware.Authenticate(repo),
middleware.AccountRequired(),
middleware.AdminRequired(),
)
h.ServeHTTP(w, r)

rr := w.Result()
test.AssertEqual(t, rr.StatusCode, http.StatusForbidden)
}
Loading

0 comments on commit c4dd426

Please sign in to comment.