-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3a4435c
commit c4dd426
Showing
5 changed files
with
407 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package middleware | ||
|
||
import ( | ||
"errors" | ||
"net/http" | ||
|
||
"github.com/theandrew168/bloggulus/backend/postgres" | ||
"github.com/theandrew168/bloggulus/backend/repository" | ||
"github.com/theandrew168/bloggulus/backend/web/util" | ||
) | ||
|
||
func Authenticate(repo *repository.Repository) Middleware { | ||
return func(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
|
||
// Check for a sessionID cookie. | ||
sessionID, err := r.Cookie(util.SessionCookieName) | ||
if err != nil { | ||
next.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
// Lookup the account linked to the session. | ||
account, err := repo.Account().ReadBySessionID(sessionID.Value) | ||
if err != nil { | ||
// If the user has an invalid / expired session cookie, delete it. | ||
if errors.Is(err, postgres.ErrNotFound) { | ||
cookie := util.NewExpiredCookie(util.SessionCookieName) | ||
http.SetCookie(w, &cookie) | ||
|
||
next.ServeHTTP(w, r) | ||
return | ||
} | ||
|
||
http.Error(w, err.Error(), 500) | ||
return | ||
} | ||
|
||
// If it exists, attach the account to the request context. | ||
r = util.ContextSetAccount(r, account) | ||
|
||
next.ServeHTTP(w, r) | ||
}) | ||
} | ||
} | ||
|
||
func AccountRequired() Middleware { | ||
return func(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// If the request context has no account, then the user is not signed in (redirect). | ||
_, ok := util.ContextGetAccount(r) | ||
if !ok { | ||
http.Redirect(w, r, "/signin", http.StatusSeeOther) | ||
return | ||
} | ||
|
||
next.ServeHTTP(w, r) | ||
}) | ||
} | ||
} | ||
|
||
func AdminRequired() Middleware { | ||
return func(next http.Handler) http.Handler { | ||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
// If the request context has no account, then the user is not signed in (redirect). | ||
account, ok := util.ContextGetAccount(r) | ||
if !ok { | ||
http.Redirect(w, r, "/signin", http.StatusSeeOther) | ||
return | ||
} | ||
|
||
// If the account exists but is not an admin account, show a 403 Forbidden page. | ||
if !account.IsAdmin() { | ||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) | ||
return | ||
} | ||
|
||
next.ServeHTTP(w, r) | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,279 @@ | ||
package middleware_test | ||
|
||
import ( | ||
"context" | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/theandrew168/bloggulus/backend/test" | ||
"github.com/theandrew168/bloggulus/backend/web/middleware" | ||
"github.com/theandrew168/bloggulus/backend/web/util" | ||
) | ||
|
||
func TestAuthenticate(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
account, _ := test.CreateAccount(t, repo) | ||
_, sessionID := test.CreateSession(t, repo, account) | ||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID) | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
got, ok := util.ContextGetAccount(r) | ||
test.AssertEqual(t, ok, true) | ||
test.AssertEqual(t, got.ID(), account.ID()) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
) | ||
h.ServeHTTP(w, r) | ||
} | ||
|
||
func TestAuthenticateNoSession(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, middleware.Authenticate(repo)) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusOK) | ||
} | ||
|
||
func TestAuthenticateInvalidSession(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, "foobar") | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, middleware.Authenticate(repo)) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusOK) | ||
} | ||
|
||
func TestAccountRequired(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
account, _ := test.CreateAccount(t, repo) | ||
_, sessionID := test.CreateSession(t, repo, account) | ||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID) | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
got, ok := util.ContextGetAccount(r) | ||
test.AssertEqual(t, ok, true) | ||
test.AssertEqual(t, got.ID(), account.ID()) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusOK) | ||
} | ||
|
||
func TestAccountRequiredNoSession(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther) | ||
test.AssertEqual(t, rr.Header.Get("Location"), "/signin") | ||
} | ||
|
||
func TestAccountRequiredInvalidSession(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, "foobar") | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther) | ||
test.AssertEqual(t, rr.Header.Get("Location"), "/signin") | ||
} | ||
|
||
func TestAdminRequired(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
account, _ := test.CreateAccount(t, repo) | ||
_, sessionID := test.CreateSession(t, repo, account) | ||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID) | ||
|
||
// Make the account an admin via manual SQL. | ||
err := repo.Exec(context.Background(), "UPDATE account SET is_admin = TRUE WHERE id = $1", account.ID()) | ||
test.AssertNilError(t, err) | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
got, ok := util.ContextGetAccount(r) | ||
test.AssertEqual(t, ok, true) | ||
test.AssertEqual(t, got.ID(), account.ID()) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
middleware.AdminRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusOK) | ||
} | ||
|
||
func TestAdminRequiredNoSession(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
middleware.AdminRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther) | ||
test.AssertEqual(t, rr.Header.Get("Location"), "/signin") | ||
} | ||
|
||
func TestAdminRequiredInvalidSession(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, "foobar") | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
middleware.AdminRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusSeeOther) | ||
test.AssertEqual(t, rr.Header.Get("Location"), "/signin") | ||
} | ||
|
||
func TestAdminRequiredNotAdmin(t *testing.T) { | ||
t.Parallel() | ||
|
||
repo, closer := test.NewRepository(t) | ||
defer closer() | ||
|
||
account, _ := test.CreateAccount(t, repo) | ||
_, sessionID := test.CreateSession(t, repo, account) | ||
sessionCookie := util.NewSessionCookie(util.SessionCookieName, sessionID) | ||
|
||
w := httptest.NewRecorder() | ||
r := httptest.NewRequest("GET", "/", nil) | ||
r.AddCookie(&sessionCookie) | ||
|
||
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
w.WriteHeader(http.StatusOK) | ||
}) | ||
|
||
h := middleware.Use(next, | ||
middleware.Authenticate(repo), | ||
middleware.AccountRequired(), | ||
middleware.AdminRequired(), | ||
) | ||
h.ServeHTTP(w, r) | ||
|
||
rr := w.Result() | ||
test.AssertEqual(t, rr.StatusCode, http.StatusForbidden) | ||
} |
Oops, something went wrong.