diff --git a/auth/auth.go b/auth/auth.go index 8ff5d7d..ebe5738 100644 --- a/auth/auth.go +++ b/auth/auth.go @@ -2,7 +2,6 @@ package auth import ( - "context" "fmt" "net/http" "net/url" @@ -14,13 +13,6 @@ import ( "github.com/rs/zerolog/log" ) -const ( - // TokenCookieKey is the key of the cookie stored in the context. - TokenCookieKey = "session_token" -) - -type claimsContextKey struct{} - // Auth is a service that provides HTTP handlers and middlewares used for authentication. type Auth struct { JWTSecret jwt.Secret @@ -123,7 +115,7 @@ func (a *Auth) CallBack() http.HandlerFunc { } cookie := &http.Cookie{ - Name: TokenCookieKey, + Name: jwt.TokenCookieKey, Value: token, Path: "/", Expires: time.Now().Add(jwt.ExpiresDuration), @@ -135,48 +127,16 @@ func (a *Auth) CallBack() http.HandlerFunc { } // Logout removes session cookies and redirect to home. -func (a *Auth) Logout() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - cookie, err := r.Cookie(TokenCookieKey) - if err != nil { - // Ignore error. Cookie doesn't exists. - http.Redirect(w, r, "/", http.StatusSeeOther) - return - } - cookie.Value = "" - cookie.Path = "/" - cookie.Expires = time.Now().Add(-1 * time.Hour) - http.SetCookie(w, cookie) +func Logout(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(jwt.TokenCookieKey) + if err != nil { + // Ignore error. Cookie doesn't exists. http.Redirect(w, r, "/", http.StatusSeeOther) + return } -} - -// Middleware is an authentication guard for HTTP servers. -func (a *Auth) Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Get the JWT token from the request header - cookie, err := r.Cookie(TokenCookieKey) - if err != nil { - next.ServeHTTP(w, r) - return - } - - // Verify the JWT token - claims, err := a.JWTSecret.VerifyToken(cookie.Value) - if err != nil { - log.Error().Err(err).Msg("token verification failed") - next.ServeHTTP(w, r) - return - } - - // Store the claims in the request context for further use - ctx := context.WithValue(r.Context(), claimsContextKey{}, *claims) - next.ServeHTTP(w, r.WithContext(ctx)) - }) -} - -// GetClaimsFromRequest is a helper function to fetch the JWT session token from an HTTP request. -func GetClaimsFromRequest(r *http.Request) (claims jwt.Claims, ok bool) { - claims, ok = r.Context().Value(claimsContextKey{}).(jwt.Claims) - return claims, ok + cookie.Value = "" + cookie.Path = "/" + cookie.Expires = time.Now().Add(-1 * time.Hour) + http.SetCookie(w, cookie) + http.Redirect(w, r, "/", http.StatusSeeOther) } diff --git a/auth/webauthn/webauthn.go b/auth/webauthn/webauthn.go index 307eeca..d4e9c16 100644 --- a/auth/webauthn/webauthn.go +++ b/auth/webauthn/webauthn.go @@ -7,7 +7,6 @@ import ( "net/http" "time" - "github.com/Darkness4/auth-htmx/auth" "github.com/Darkness4/auth-htmx/auth/webauthn/session" "github.com/Darkness4/auth-htmx/database/user" "github.com/Darkness4/auth-htmx/jwt" @@ -162,7 +161,7 @@ func (s *Service) FinishLogin() http.HandlerFunc { } cookie := &http.Cookie{ - Name: auth.TokenCookieKey, + Name: jwt.TokenCookieKey, Value: token, Path: "/", Expires: time.Now().Add(jwt.ExpiresDuration), @@ -290,7 +289,7 @@ func (s *Service) FinishRegistration() http.HandlerFunc { } cookie := &http.Cookie{ - Name: auth.TokenCookieKey, + Name: jwt.TokenCookieKey, Value: token, Path: "/", Expires: time.Now().Add(jwt.ExpiresDuration), @@ -309,7 +308,7 @@ func (s *Service) FinishRegistration() http.HandlerFunc { // Compared to BeginRegistration, BeginAddDevice uses the JWT to allow the registration. func (s *Service) BeginAddDevice() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - claims, ok := auth.GetClaimsFromRequest(r) + claims, ok := jwt.GetClaimsFromRequest(r) if !ok { http.Error(w, "session not found", http.StatusForbidden) return @@ -365,7 +364,7 @@ func (s *Service) BeginAddDevice() http.HandlerFunc { // We complete the registration. func (s *Service) FinishAddDevice() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - claims, ok := auth.GetClaimsFromRequest(r) + claims, ok := jwt.GetClaimsFromRequest(r) if !ok { http.Error(w, "session not found", http.StatusForbidden) return @@ -431,7 +430,7 @@ func (s *Service) FinishAddDevice() http.HandlerFunc { } cookie := &http.Cookie{ - Name: auth.TokenCookieKey, + Name: jwt.TokenCookieKey, Value: token, Path: "/", Expires: time.Now().Add(jwt.ExpiresDuration), @@ -458,7 +457,7 @@ func (s *Service) DeleteDevice() http.HandlerFunc { return } - claims, ok := auth.GetClaimsFromRequest(r) + claims, ok := jwt.GetClaimsFromRequest(r) if !ok { http.Error(w, "session not found", http.StatusForbidden) return diff --git a/handler/handler_counter.go b/handler/handler_counter.go index f5b342b..2a9bf7d 100644 --- a/handler/handler_counter.go +++ b/handler/handler_counter.go @@ -5,14 +5,14 @@ import ( "fmt" "net/http" - "github.com/Darkness4/auth-htmx/auth" "github.com/Darkness4/auth-htmx/database/counter" + "github.com/Darkness4/auth-htmx/jwt" ) // Count increments the counter and returns the new value. func Count(counter counter.Repository) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - claims, ok := auth.GetClaimsFromRequest(r) + claims, ok := jwt.GetClaimsFromRequest(r) if !ok { http.Error(w, "not allowed", http.StatusUnauthorized) return diff --git a/jwt/jwt.go b/jwt/jwt.go index 08eb5c0..66325f1 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -2,13 +2,22 @@ package jwt import ( + "context" "fmt" + "net/http" "time" "github.com/go-webauthn/webauthn/webauthn" "github.com/golang-jwt/jwt/v5" ) +const ( + // TokenCookieKey is the key of the cookie stored in the context. + TokenCookieKey = "session_token" +) + +type claimsContextKey struct{} + // ExpiresDuration is the duration when a user session expires. const ExpiresDuration = 24 * time.Hour @@ -106,3 +115,32 @@ func (s Secret) VerifyToken(tokenString string) (*Claims, error) { return nil, fmt.Errorf("invalid token") } + +// Middleware is an authentication guard for HTTP servers. +func (jwt Secret) Middleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get the JWT token from the request header + cookie, err := r.Cookie(TokenCookieKey) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Verify the JWT token + claims, err := jwt.VerifyToken(cookie.Value) + if err != nil { + next.ServeHTTP(w, r) + return + } + + // Store the claims in the request context for further use + ctx := context.WithValue(r.Context(), claimsContextKey{}, *claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +// GetClaimsFromRequest is a helper function to fetch the JWT session token from an HTTP request. +func GetClaimsFromRequest(r *http.Request) (claims Claims, ok bool) { + claims, ok = r.Context().Value(claimsContextKey{}).(Claims) + return claims, ok +} diff --git a/main.go b/main.go index 6836b67..603186d 100644 --- a/main.go +++ b/main.go @@ -133,25 +133,8 @@ var app = &cli.App{ // Router r := chi.NewRouter() + r.Use(jwt.Secret(jwtSecret).Middleware) r.Use(hlog.NewHandler(log.Logger)) - r.Use(authService.Middleware) - - // Auth Guard - r.Use(func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, isAuth := auth.GetClaimsFromRequest(r) - - if !isAuth { - switch r.URL.Path { - case "/counter": - http.Error(w, "unauthorized", http.StatusUnauthorized) - return - } - } - - next.ServeHTTP(w, r) - }) - }) // DB d, err := sql.Open("sqlite", dbFile) @@ -166,7 +149,7 @@ var app = &cli.App{ // Auth r.Get("/login", authService.Login()) - r.Get("/logout", authService.Logout()) + r.Get("/logout", auth.Logout) r.Get("/callback", authService.CallBack()) u, err := url.Parse(publicURL) @@ -217,7 +200,7 @@ var app = &cli.App{ path := filepath.Clean(r.URL.Path) path = filepath.Clean(fmt.Sprintf("pages/%s/page.tmpl", path)) - claims, _ := auth.GetClaimsFromRequest(r) + claims, _ := jwt.GetClaimsFromRequest(r) // Check if SSR var base string