Skip to content

Commit

Permalink
refactor: validate jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
crazyoptimist committed Jun 16, 2024
1 parent 514dae2 commit 95ca05e
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 44 deletions.
25 changes: 12 additions & 13 deletions internal/domain/auth/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func GenerateJwtToken(keyId int, userId int) (string, error) {
issuedAt := time.Now()
claims["iat"] = issuedAt.Unix()
claims["exp"] = issuedAt.Add(expiresIn).Unix()
claims["sub"] = strconv.Itoa(int(userId))
claims["sub"] = strconv.Itoa(userId)

tokenString, err := token.SignedString(secretKey)
if err != nil {
Expand All @@ -65,44 +65,43 @@ func GenerateTokenPair(userId int) (accessToken, refreshToken string, err error)
return
}

func ValidateToken(tokenString string) (isValid bool, userId int, keyId int, err error) {

var key []byte
func ValidateJwtToken(tokenString string) (isValid bool, sub string, keyId int, err error) {
var secret []byte

claims := jwt.MapClaims{}

token, err := jwt.ParseWithClaims(
parsedToken, err := jwt.ParseWithClaims(
tokenString,
claims,
func(token *jwt.Token) (interface{}, error) {

keyId = int(token.Header["kid"].(float64))

switch keyId {
case AccessTokenKeyId:
key = []byte(config.Global.JwtAccessTokenSecret)
secret = []byte(config.Global.JwtAccessTokenSecret)
case RefreshTokenKeyId:
key = []byte(config.Global.JwtRefreshTokenSecret)
secret = []byte(config.Global.JwtRefreshTokenSecret)
}

return key, nil
return secret, nil
},
)

if err != nil {
common.Logger.Error("JWT validation failed: ", err)
return
}

if !token.Valid {
if !parsedToken.Valid {
err = errors.New("Invalid JWT token")
return
}

isValid = true

sub, err := claims.GetSubject()
userId, err = strconv.Atoi(sub)
sub, err = parsedToken.Claims.GetSubject()
if err != nil {
return
}

return
}
2 changes: 1 addition & 1 deletion internal/domain/auth/jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func TestGenerateAccessToken(t *testing.T) {
accessToken, err := GenerateJwtToken(AccessTokenKeyId, userId)
assert.NoError(t, err)

isValid, _, _, _ := ValidateToken(accessToken)
isValid, _, _, _ := ValidateJwtToken(accessToken)

if isValid != true {
t.Errorf(
Expand Down
6 changes: 4 additions & 2 deletions internal/domain/auth/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"errors"
"fmt"
"strconv"

"gin-starter/internal/domain/model"
"gin-starter/internal/domain/user"
Expand Down Expand Up @@ -86,9 +87,9 @@ func (s *AuthService) Logout(logoutDto *LogoutDto) error {
return s.AuthHelper.BlacklistToken(logoutDto.RefreshToken)
}

func (s *AuthService) RefreshToken(logoutDto *LogoutDto) (*LoginResponse, error) {
func (s *AuthService) Refresh(logoutDto *LogoutDto) (*LoginResponse, error) {

isTokenValid, userId, keyId, err := ValidateToken(logoutDto.RefreshToken)
isTokenValid, userIdString, keyId, err := ValidateJwtToken(logoutDto.RefreshToken)
if err != nil {
return nil, err
}
Expand All @@ -115,6 +116,7 @@ func (s *AuthService) RefreshToken(logoutDto *LogoutDto) (*LoginResponse, error)
return nil, err
}

userId, err := strconv.Atoi(userIdString)
accessToken, refreshToken, err := GenerateTokenPair(userId)
if err != nil {
return nil, err
Expand Down
4 changes: 2 additions & 2 deletions internal/infrastructure/controller/auth_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,14 @@ func (a *authController) Logout(c *gin.Context) {
// @Failure 400 {object} common.HttpError
// @Failure 500 {object} common.HttpError
// @Router /auth/refresh [post]
func (a *authController) RefreshToken(c *gin.Context) {
func (a *authController) Refresh(c *gin.Context) {
var refreshDto auth.LogoutDto
if err := c.BindJSON(&refreshDto); err != nil {
common.RaiseHttpError(c, http.StatusBadRequest, err)
return
}

refreshResponse, err := a.AuthService.RefreshToken(&refreshDto)
refreshResponse, err := a.AuthService.Refresh(&refreshDto)
if err != nil {
if errors.Is(err, auth.ErrTokenBlacklisted) {
common.RaiseHttpError(c, http.StatusUnauthorized, err)
Expand Down
34 changes: 9 additions & 25 deletions internal/infrastructure/middleware/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ import (
"time"

"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"

"gin-starter/internal/config"
"gin-starter/internal/domain/auth"
"gin-starter/internal/domain/model"
"gin-starter/internal/infrastructure/controller"
"gin-starter/pkg/common"
Expand All @@ -36,34 +36,18 @@ func AuthMiddleware() gin.HandlerFunc {
return
}

claims := jwt.MapClaims{}
parsedToken, err := jwt.ParseWithClaims(
accessToken,
claims,
func(token *jwt.Token) (interface{}, error) {
return []byte(config.Global.JwtAccessTokenSecret), nil
},
)

isJwtValid, userIdString, _, err := auth.ValidateJwtToken(accessToken)
if err != nil {
if err == jwt.ErrSignatureInvalid {
raiseUnauthorizedError(c, "Invalid authorization token signature")
return
}
raiseUnauthorizedError(c, "Invalid authorization token")
return
}

if !parsedToken.Valid {
raiseUnauthorizedError(c, "Invalid authorization token")
common.RaiseHttpError(
c,
http.StatusUnauthorized,
err,
)
return
}

userIdString, err := parsedToken.Claims.GetSubject()
if err != nil {
raiseUnauthorizedError(c, "Missing subject in JWT")
if !isJwtValid {
raiseUnauthorizedError(c, "Invalid access token")
return

}

var user model.User
Expand Down
2 changes: 1 addition & 1 deletion internal/infrastructure/server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func registerAuthRoutes(g *gin.RouterGroup) {
g.POST("/register", controllers.Register)
g.POST("/login", controllers.Login)
g.POST("/logout", controllers.Logout, middleware.AuthMiddleware())
g.POST("/refresh", controllers.RefreshToken)
g.POST("/refresh", controllers.Refresh)
}

func registerUserRoutes(g *gin.RouterGroup) {
Expand Down

0 comments on commit 95ca05e

Please sign in to comment.