Skip to content

Commit

Permalink
Added role based auth
Browse files Browse the repository at this point in the history
  • Loading branch information
Franco Ferraguti committed Oct 5, 2023
1 parent f483d79 commit cdb3d4d
Show file tree
Hide file tree
Showing 23 changed files with 392 additions and 194 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# VSCode
.vscode/
43 changes: 33 additions & 10 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package main

import (
"log"
"os"

"github.com/gilperopiola/go-rest-example/pkg/auth"
"github.com/gilperopiola/go-rest-example/pkg/codec"
Expand All @@ -14,23 +13,47 @@ import (

func main() {

// It all starts here!
log.Println("Server started")

// Setup dependencies
var (
config = config.NewConfig()
auth = auth.NewAuth(config.JWT.SECRET, config.JWT.SESSION_DURATION_DAYS)
codec = codec.NewCodec()
database = repository.NewDatabase(config.DATABASE)

// Load configuration settings
config = config.NewConfig()

// Initialize authentication module
auth = auth.NewAuth(config.JWT.SECRET, config.JWT.SESSION_DURATION_DAYS)

// Setup codec for encoding and decoding
codec = codec.NewCodec()

// Establish database connection
database = repository.NewDatabase(config.DATABASE)

// Initialize repository with the database connection
repository = repository.NewRepository(database)
service = service.NewService(repository, auth, codec, config, service.ErrorsMapper{})
endpoints = transport.NewTransport(service, codec, transport.ErrorsMapper{})
router = transport.NewRouter(endpoints, config, auth)

// Setup the main service with dependencies
service = service.NewService(repository, auth, codec, config, service.ErrorsMapper{})

// Setup endpoints & transport layer with dependencies
endpoints = transport.NewTransport(service, codec, transport.ErrorsMapper{})

// Initialize the router with the endpoints
router = transport.NewRouter(endpoints, config, auth)
)

// Defer closing open connections
defer database.Close()

// Start server
log.Println("Server running")
router.Run(":" + os.Getenv("PORT"))
log.Println("About to run server on port " + config.PORT)

err := router.Run(":" + config.PORT)
if err != nil {
log.Fatalf("Failed to start server: %v", err)
}

// Have a nice day!
}
18 changes: 0 additions & 18 deletions config.json

This file was deleted.

88 changes: 19 additions & 69 deletions pkg/auth/auth.go
Original file line number Diff line number Diff line change
@@ -1,32 +1,23 @@
package auth

import (
"fmt"
"net/http"
"strings"
"time"

"github.com/gilperopiola/go-rest-example/pkg/entities"

"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
)

const (
UNAUTHORIZED = "unauthorized"
)

type Auth struct {
secret string
sessionDurationDays int
}

type AuthProvider interface {
GenerateToken(user entities.User) string
type AuthInterface interface {
GenerateToken(user entities.User, role AuthRole) string
ValidateToken() gin.HandlerFunc

decodeToken(tokenString string) (*jwt.Token, error)
getTokenStringFromHeaders(c *gin.Context) string
ValidateRole(role AuthRole) gin.HandlerFunc
GetUserRole() AuthRole
GetAdminRole() AuthRole
}

func NewAuth(secret string, sessionDurationDays int) *Auth {
Expand All @@ -36,65 +27,24 @@ func NewAuth(secret string, sessionDurationDays int) *Auth {
}
}

/* ----------------------- */

func (auth *Auth) GenerateToken(user entities.User) string {

issuedAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Hour * 24 * time.Duration(auth.sessionDurationDays)).Unix()

token := jwt.NewWithClaims(jwt.SigningMethodHS256,
jwt.StandardClaims{
Id: fmt.Sprint(user.ID),
Audience: user.Email,
IssuedAt: issuedAt,
ExpiresAt: expiresAt,
},
)

tokenString, _ := token.SignedString([]byte(auth.secret))
return tokenString
type CustomClaims struct {
Username string `json:"username"`
Email string `json:"email"`
Role AuthRole `json:"role"`
jwt.StandardClaims
}

func (auth *Auth) ValidateToken() gin.HandlerFunc {
return func(c *gin.Context) {

// Get token string from context
tokenString := auth.getTokenStringFromHeaders(c)
type AuthRole string

// Decode string into actual *jwt.Token
token, err := auth.decodeToken(tokenString)
if err != nil {
c.JSON(http.StatusUnauthorized, UNAUTHORIZED)
c.Abort()
return
}

// Check if token is valid, then set ID and Email in context
if claims, ok := token.Claims.(*jwt.StandardClaims); ok && token.Valid {
c.Set("ID", claims.Id)
c.Set("Email", claims.Audience)
} else {
c.JSON(http.StatusUnauthorized, UNAUTHORIZED)
c.Abort()
}
}
}

// decodeToken decodes a JWT token string into a *jwt.Token
func (auth *Auth) decodeToken(tokenString string) (*jwt.Token, error) {
if len(tokenString) < 40 {
return &jwt.Token{}, nil
}

keyFunc := func(token *jwt.Token) (interface{}, error) {
return []byte(auth.secret), nil
}
const (
UserRole AuthRole = "user"
AdminRole AuthRole = "admin"
)

return jwt.ParseWithClaims(tokenString, &jwt.StandardClaims{}, keyFunc)
func (auth *Auth) GetUserRole() AuthRole {
return UserRole
}

func (auth *Auth) getTokenStringFromHeaders(c *gin.Context) string {
tokenString := c.Request.Header.Get("Authorization")
return strings.TrimPrefix(tokenString, "Bearer ")
func (auth *Auth) GetAdminRole() AuthRole {
return AdminRole
}
38 changes: 38 additions & 0 deletions pkg/auth/token_generation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package auth

import (
"fmt"
"time"

"github.com/gilperopiola/go-rest-example/pkg/entities"

"github.com/dgrijalva/jwt-go"
)

func (auth *Auth) GenerateToken(user entities.User, role AuthRole) string {

var (
issuedAt = time.Now().Unix()
expiresAt = time.Now().Add(time.Hour * 24 * time.Duration(auth.sessionDurationDays)).Unix()
)

claims := &CustomClaims{
Username: user.Username,
Email: user.Email,
Role: role,
StandardClaims: jwt.StandardClaims{
Id: fmt.Sprint(user.ID),
Audience: user.Email,
IssuedAt: issuedAt,
ExpiresAt: expiresAt,
},
}

// Generate token (struct)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)

// Generate token (string)
tokenString, _ := token.SignedString([]byte(auth.secret))

return tokenString
}
94 changes: 94 additions & 0 deletions pkg/auth/token_validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package auth

import (
"net/http"
"strings"

"github.com/dgrijalva/jwt-go"
"github.com/gin-gonic/gin"
)

const (
UnauthorizedMsg = "unauthorized"
)

// ValidateToken validates a token for any role and sets ID and Email in context
func (auth *Auth) ValidateToken() gin.HandlerFunc {
return func(c *gin.Context) {

// Get token first as a string and then as a *jwt.Token
token := auth.getTokenStructFromContext(c)

// Check if token is valid, then set ID and Email in context
if claims, ok := token.Claims.(*CustomClaims); ok && token.Valid {
c.Set("ID", claims.Id)
c.Set("Email", claims.Audience)
return
}

c.JSON(http.StatusUnauthorized, UnauthorizedMsg)
c.Abort()
}
}

// ValidateRole validates a token for a specific role and sets ID and Email in context
func (auth *Auth) ValidateRole(role AuthRole) gin.HandlerFunc {
return func(c *gin.Context) {

// Get token first as a string and then as a *jwt.Token
token := auth.getTokenStructFromContext(c)

// Get custom claims from token
customClaims, ok := token.Claims.(*CustomClaims)

// Check if token is valid, then set ID and Email in context
if ok && token.Valid && customClaims.Role == role {
c.Set("ID", customClaims.Id)
c.Set("Email", customClaims.Audience)
return
}

c.JSON(http.StatusUnauthorized, UnauthorizedMsg)
c.Abort()
}
}

func (auth *Auth) getTokenStructFromContext(c *gin.Context) *jwt.Token {

// Get token string from context
tokenString := removeBearerPrefix(auth.getJWTStringFromHeader(c.Request.Header))

// Decode string into actual *jwt.Token
token, err := auth.decodeTokenString(tokenString)
if err == nil {
return token
}

// Error decoding token
c.JSON(http.StatusUnauthorized, UnauthorizedMsg)
c.Abort()
return nil
}

// decodeTokenString decodes a JWT token string into a *jwt.Token
func (auth *Auth) decodeTokenString(tokenString string) (*jwt.Token, error) {

// Check length
if len(tokenString) < 40 {
return &jwt.Token{}, nil
}

// Make key function
keyFunc := func(token *jwt.Token) (interface{}, error) { return []byte(auth.secret), nil }

// Parse
return jwt.ParseWithClaims(tokenString, &CustomClaims{}, keyFunc)
}

func (auth *Auth) getJWTStringFromHeader(header http.Header) string {
return header.Get("Authorization")
}

func removeBearerPrefix(token string) string {
return strings.TrimPrefix(token, "Bearer ")
}
2 changes: 1 addition & 1 deletion pkg/codec/codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (

type Codec struct{}

type CodecProvider interface {
type CodecInterface interface {

// From Requests to Models
FromSignupRequestToUserModel(request entities.SignupRequest, hashedPassword string) models.User
Expand Down
Loading

0 comments on commit cdb3d4d

Please sign in to comment.