diff --git a/internal/domain/auth/jwt.go b/internal/domain/auth/jwt.go index aa3a3a4..58b21ca 100644 --- a/internal/domain/auth/jwt.go +++ b/internal/domain/auth/jwt.go @@ -40,7 +40,7 @@ func GenerateJwtToken(keyId int, userId int) (string, error) { issuedAt := time.Now() claims["iat"] = issuedAt.Unix() claims["exp"] = issuedAt.Add(expiresIn).Unix() - claims["sub"] = strconv.Itoa(int(userId)) + claims["sub"] = strconv.Itoa(userId) tokenString, err := token.SignedString(secretKey) if err != nil { @@ -65,44 +65,43 @@ func GenerateTokenPair(userId int) (accessToken, refreshToken string, err error) return } -func ValidateToken(tokenString string) (isValid bool, userId int, keyId int, err error) { - - var key []byte +func ValidateJwtToken(tokenString string) (isValid bool, sub string, keyId int, err error) { + var secret []byte claims := jwt.MapClaims{} - token, err := jwt.ParseWithClaims( + parsedToken, err := jwt.ParseWithClaims( tokenString, claims, func(token *jwt.Token) (interface{}, error) { - keyId = int(token.Header["kid"].(float64)) switch keyId { case AccessTokenKeyId: - key = []byte(config.Global.JwtAccessTokenSecret) + secret = []byte(config.Global.JwtAccessTokenSecret) case RefreshTokenKeyId: - key = []byte(config.Global.JwtRefreshTokenSecret) + secret = []byte(config.Global.JwtRefreshTokenSecret) } - return key, nil + return secret, nil }, ) - if err != nil { common.Logger.Error("JWT validation failed: ", err) return } - if !token.Valid { + if !parsedToken.Valid { err = errors.New("Invalid JWT token") return } isValid = true - sub, err := claims.GetSubject() - userId, err = strconv.Atoi(sub) + sub, err = parsedToken.Claims.GetSubject() + if err != nil { + return + } return } diff --git a/internal/domain/auth/jwt_test.go b/internal/domain/auth/jwt_test.go index fa8f046..eba9085 100644 --- a/internal/domain/auth/jwt_test.go +++ b/internal/domain/auth/jwt_test.go @@ -21,7 +21,7 @@ func TestGenerateAccessToken(t *testing.T) { accessToken, err := GenerateJwtToken(AccessTokenKeyId, userId) assert.NoError(t, err) - isValid, _, _, _ := ValidateToken(accessToken) + isValid, _, _, _ := ValidateJwtToken(accessToken) if isValid != true { t.Errorf( diff --git a/internal/domain/auth/service.go b/internal/domain/auth/service.go index 834e540..3690981 100644 --- a/internal/domain/auth/service.go +++ b/internal/domain/auth/service.go @@ -3,6 +3,7 @@ package auth import ( "errors" "fmt" + "strconv" "gin-starter/internal/domain/model" "gin-starter/internal/domain/user" @@ -86,9 +87,9 @@ func (s *AuthService) Logout(logoutDto *LogoutDto) error { return s.AuthHelper.BlacklistToken(logoutDto.RefreshToken) } -func (s *AuthService) RefreshToken(logoutDto *LogoutDto) (*LoginResponse, error) { +func (s *AuthService) Refresh(logoutDto *LogoutDto) (*LoginResponse, error) { - isTokenValid, userId, keyId, err := ValidateToken(logoutDto.RefreshToken) + isTokenValid, userIdString, keyId, err := ValidateJwtToken(logoutDto.RefreshToken) if err != nil { return nil, err } @@ -115,6 +116,7 @@ func (s *AuthService) RefreshToken(logoutDto *LogoutDto) (*LoginResponse, error) return nil, err } + userId, err := strconv.Atoi(userIdString) accessToken, refreshToken, err := GenerateTokenPair(userId) if err != nil { return nil, err diff --git a/internal/infrastructure/controller/auth_controller.go b/internal/infrastructure/controller/auth_controller.go index d799fd8..db634c8 100644 --- a/internal/infrastructure/controller/auth_controller.go +++ b/internal/infrastructure/controller/auth_controller.go @@ -107,14 +107,14 @@ func (a *authController) Logout(c *gin.Context) { // @Failure 400 {object} common.HttpError // @Failure 500 {object} common.HttpError // @Router /auth/refresh [post] -func (a *authController) RefreshToken(c *gin.Context) { +func (a *authController) Refresh(c *gin.Context) { var refreshDto auth.LogoutDto if err := c.BindJSON(&refreshDto); err != nil { common.RaiseHttpError(c, http.StatusBadRequest, err) return } - refreshResponse, err := a.AuthService.RefreshToken(&refreshDto) + refreshResponse, err := a.AuthService.Refresh(&refreshDto) if err != nil { if errors.Is(err, auth.ErrTokenBlacklisted) { common.RaiseHttpError(c, http.StatusUnauthorized, err) diff --git a/internal/infrastructure/middleware/auth_middleware.go b/internal/infrastructure/middleware/auth_middleware.go index 4a63ca4..22c1002 100644 --- a/internal/infrastructure/middleware/auth_middleware.go +++ b/internal/infrastructure/middleware/auth_middleware.go @@ -8,9 +8,9 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt/v5" "gin-starter/internal/config" + "gin-starter/internal/domain/auth" "gin-starter/internal/domain/model" "gin-starter/internal/infrastructure/controller" "gin-starter/pkg/common" @@ -36,34 +36,18 @@ func AuthMiddleware() gin.HandlerFunc { return } - claims := jwt.MapClaims{} - parsedToken, err := jwt.ParseWithClaims( - accessToken, - claims, - func(token *jwt.Token) (interface{}, error) { - return []byte(config.Global.JwtAccessTokenSecret), nil - }, - ) - + isJwtValid, userIdString, _, err := auth.ValidateJwtToken(accessToken) if err != nil { - if err == jwt.ErrSignatureInvalid { - raiseUnauthorizedError(c, "Invalid authorization token signature") - return - } - raiseUnauthorizedError(c, "Invalid authorization token") - return - } - - if !parsedToken.Valid { - raiseUnauthorizedError(c, "Invalid authorization token") + common.RaiseHttpError( + c, + http.StatusUnauthorized, + err, + ) return } - - userIdString, err := parsedToken.Claims.GetSubject() - if err != nil { - raiseUnauthorizedError(c, "Missing subject in JWT") + if !isJwtValid { + raiseUnauthorizedError(c, "Invalid access token") return - } var user model.User diff --git a/internal/infrastructure/server/router.go b/internal/infrastructure/server/router.go index 3f95e57..1cbae4c 100644 --- a/internal/infrastructure/server/router.go +++ b/internal/infrastructure/server/router.go @@ -13,7 +13,7 @@ func registerAuthRoutes(g *gin.RouterGroup) { g.POST("/register", controllers.Register) g.POST("/login", controllers.Login) g.POST("/logout", controllers.Logout, middleware.AuthMiddleware()) - g.POST("/refresh", controllers.RefreshToken) + g.POST("/refresh", controllers.Refresh) } func registerUserRoutes(g *gin.RouterGroup) {