Skip to content

Commit

Permalink
refactor: move things related to jwt to jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
Darkness4 committed Jan 24, 2024
1 parent 5834905 commit 60c5e59
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 80 deletions.
62 changes: 11 additions & 51 deletions auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
package auth

import (
"context"
"fmt"
"net/http"
"net/url"
Expand All @@ -14,13 +13,6 @@ import (
"github.com/rs/zerolog/log"
)

const (
// TokenCookieKey is the key of the cookie stored in the context.
TokenCookieKey = "session_token"
)

type claimsContextKey struct{}

// Auth is a service that provides HTTP handlers and middlewares used for authentication.
type Auth struct {
JWTSecret jwt.Secret
Expand Down Expand Up @@ -123,7 +115,7 @@ func (a *Auth) CallBack() http.HandlerFunc {
}

cookie := &http.Cookie{
Name: TokenCookieKey,
Name: jwt.TokenCookieKey,
Value: token,
Path: "/",
Expires: time.Now().Add(jwt.ExpiresDuration),
Expand All @@ -135,48 +127,16 @@ func (a *Auth) CallBack() http.HandlerFunc {
}

// Logout removes session cookies and redirect to home.
func (a *Auth) Logout() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(TokenCookieKey)
if err != nil {
// Ignore error. Cookie doesn't exists.
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
cookie.Value = ""
cookie.Path = "/"
cookie.Expires = time.Now().Add(-1 * time.Hour)
http.SetCookie(w, cookie)
func Logout(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie(jwt.TokenCookieKey)
if err != nil {
// Ignore error. Cookie doesn't exists.
http.Redirect(w, r, "/", http.StatusSeeOther)
return
}
}

// Middleware is an authentication guard for HTTP servers.
func (a *Auth) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the JWT token from the request header
cookie, err := r.Cookie(TokenCookieKey)
if err != nil {
next.ServeHTTP(w, r)
return
}

// Verify the JWT token
claims, err := a.JWTSecret.VerifyToken(cookie.Value)
if err != nil {
log.Error().Err(err).Msg("token verification failed")
next.ServeHTTP(w, r)
return
}

// Store the claims in the request context for further use
ctx := context.WithValue(r.Context(), claimsContextKey{}, *claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

// GetClaimsFromRequest is a helper function to fetch the JWT session token from an HTTP request.
func GetClaimsFromRequest(r *http.Request) (claims jwt.Claims, ok bool) {
claims, ok = r.Context().Value(claimsContextKey{}).(jwt.Claims)
return claims, ok
cookie.Value = ""
cookie.Path = "/"
cookie.Expires = time.Now().Add(-1 * time.Hour)
http.SetCookie(w, cookie)
http.Redirect(w, r, "/", http.StatusSeeOther)
}
13 changes: 6 additions & 7 deletions auth/webauthn/webauthn.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net/http"
"time"

"github.com/Darkness4/auth-htmx/auth"
"github.com/Darkness4/auth-htmx/auth/webauthn/session"
"github.com/Darkness4/auth-htmx/database/user"
"github.com/Darkness4/auth-htmx/jwt"
Expand Down Expand Up @@ -162,7 +161,7 @@ func (s *Service) FinishLogin() http.HandlerFunc {
}

cookie := &http.Cookie{
Name: auth.TokenCookieKey,
Name: jwt.TokenCookieKey,
Value: token,
Path: "/",
Expires: time.Now().Add(jwt.ExpiresDuration),
Expand Down Expand Up @@ -290,7 +289,7 @@ func (s *Service) FinishRegistration() http.HandlerFunc {
}

cookie := &http.Cookie{
Name: auth.TokenCookieKey,
Name: jwt.TokenCookieKey,
Value: token,
Path: "/",
Expires: time.Now().Add(jwt.ExpiresDuration),
Expand All @@ -309,7 +308,7 @@ func (s *Service) FinishRegistration() http.HandlerFunc {
// Compared to BeginRegistration, BeginAddDevice uses the JWT to allow the registration.
func (s *Service) BeginAddDevice() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims, ok := auth.GetClaimsFromRequest(r)
claims, ok := jwt.GetClaimsFromRequest(r)
if !ok {
http.Error(w, "session not found", http.StatusForbidden)
return
Expand Down Expand Up @@ -365,7 +364,7 @@ func (s *Service) BeginAddDevice() http.HandlerFunc {
// We complete the registration.
func (s *Service) FinishAddDevice() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims, ok := auth.GetClaimsFromRequest(r)
claims, ok := jwt.GetClaimsFromRequest(r)
if !ok {
http.Error(w, "session not found", http.StatusForbidden)
return
Expand Down Expand Up @@ -431,7 +430,7 @@ func (s *Service) FinishAddDevice() http.HandlerFunc {
}

cookie := &http.Cookie{
Name: auth.TokenCookieKey,
Name: jwt.TokenCookieKey,
Value: token,
Path: "/",
Expires: time.Now().Add(jwt.ExpiresDuration),
Expand All @@ -458,7 +457,7 @@ func (s *Service) DeleteDevice() http.HandlerFunc {
return
}

claims, ok := auth.GetClaimsFromRequest(r)
claims, ok := jwt.GetClaimsFromRequest(r)
if !ok {
http.Error(w, "session not found", http.StatusForbidden)
return
Expand Down
4 changes: 2 additions & 2 deletions handler/handler_counter.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ import (
"fmt"
"net/http"

"github.com/Darkness4/auth-htmx/auth"
"github.com/Darkness4/auth-htmx/database/counter"
"github.com/Darkness4/auth-htmx/jwt"
)

// Count increments the counter and returns the new value.
func Count(counter counter.Repository) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
claims, ok := auth.GetClaimsFromRequest(r)
claims, ok := jwt.GetClaimsFromRequest(r)
if !ok {
http.Error(w, "not allowed", http.StatusUnauthorized)
return
Expand Down
38 changes: 38 additions & 0 deletions jwt/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,22 @@
package jwt

import (
"context"
"fmt"
"net/http"
"time"

"github.com/go-webauthn/webauthn/webauthn"
"github.com/golang-jwt/jwt/v5"
)

const (
// TokenCookieKey is the key of the cookie stored in the context.
TokenCookieKey = "session_token"
)

type claimsContextKey struct{}

// ExpiresDuration is the duration when a user session expires.
const ExpiresDuration = 24 * time.Hour

Expand Down Expand Up @@ -106,3 +115,32 @@ func (s Secret) VerifyToken(tokenString string) (*Claims, error) {

return nil, fmt.Errorf("invalid token")
}

// Middleware is an authentication guard for HTTP servers.
func (jwt Secret) Middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Get the JWT token from the request header
cookie, err := r.Cookie(TokenCookieKey)
if err != nil {
next.ServeHTTP(w, r)
return
}

// Verify the JWT token
claims, err := jwt.VerifyToken(cookie.Value)
if err != nil {
next.ServeHTTP(w, r)
return
}

// Store the claims in the request context for further use
ctx := context.WithValue(r.Context(), claimsContextKey{}, *claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}

// GetClaimsFromRequest is a helper function to fetch the JWT session token from an HTTP request.
func GetClaimsFromRequest(r *http.Request) (claims Claims, ok bool) {
claims, ok = r.Context().Value(claimsContextKey{}).(Claims)
return claims, ok
}
23 changes: 3 additions & 20 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,25 +133,8 @@ var app = &cli.App{

// Router
r := chi.NewRouter()
r.Use(jwt.Secret(jwtSecret).Middleware)
r.Use(hlog.NewHandler(log.Logger))
r.Use(authService.Middleware)

// Auth Guard
r.Use(func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, isAuth := auth.GetClaimsFromRequest(r)

if !isAuth {
switch r.URL.Path {
case "/counter":
http.Error(w, "unauthorized", http.StatusUnauthorized)
return
}
}

next.ServeHTTP(w, r)
})
})

// DB
d, err := sql.Open("sqlite", dbFile)
Expand All @@ -166,7 +149,7 @@ var app = &cli.App{

// Auth
r.Get("/login", authService.Login())
r.Get("/logout", authService.Logout())
r.Get("/logout", auth.Logout)
r.Get("/callback", authService.CallBack())

u, err := url.Parse(publicURL)
Expand Down Expand Up @@ -217,7 +200,7 @@ var app = &cli.App{
path := filepath.Clean(r.URL.Path)
path = filepath.Clean(fmt.Sprintf("pages/%s/page.tmpl", path))

claims, _ := auth.GetClaimsFromRequest(r)
claims, _ := jwt.GetClaimsFromRequest(r)

// Check if SSR
var base string
Expand Down

0 comments on commit 60c5e59

Please sign in to comment.