diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..258d195 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,4 @@ +db-data +/cmd/server/db-data +.git +.gitignore diff --git a/.gitignore b/.gitignore index 50d277e..4010793 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ c.out .env pgdata auth_test.go +static \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go index b5ca740..db465ad 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,16 +2,16 @@ package main import ( "context" + "kudago/internal/db" "log" "net/http" - "kudago/internal/db" - "kudago/config" _ "kudago/docs" "kudago/internal/http/auth" "kudago/internal/http/events" "kudago/internal/middleware" + csrfRepository "kudago/internal/repository/csrf" eventRepository "kudago/internal/repository/events" sessionRepository "kudago/internal/repository/session" userRepository "kudago/internal/repository/users" @@ -33,6 +33,7 @@ import ( func main() { port := config.LoadConfig() + encryptionKey := config.LoadEncriptionKey() logger, err := zap.NewProduction() if err != nil { @@ -55,9 +56,10 @@ func main() { userDB := userRepository.NewDB(pool) sessionDB := sessionRepository.NewDB(redisClient) + csrfDB := csrfRepository.NewDB(redisClient) eventDB := eventRepository.NewDB(pool) - authService := authService.NewService(userDB, sessionDB) + authService := authService.NewService(userDB, sessionDB, csrfDB) eventService := eventService.NewService(eventDB) authHandler := auth.NewAuthHandler(&authService) @@ -96,7 +98,7 @@ func main() { "/categories", } - handlerWithAuth := middleware.AuthMiddleware(whitelist, authHandler, r) + handlerWithAuth := middleware.AuthWithCSRFMiddleware(whitelist, authHandler, encryptionKey, r) handlerWithCORS := middleware.CORSMiddleware(handlerWithAuth) handlerWithLogging := middleware.LoggingMiddleware(handlerWithCORS, sugar) handler := middleware.PanicMiddleware(handlerWithLogging) diff --git a/config/config.go b/config/config.go index 0280dc2..af75204 100644 --- a/config/config.go +++ b/config/config.go @@ -18,3 +18,8 @@ func LoadConfig() string { log.Printf("Используется порт: %s", port) return port } + +func LoadEncriptionKey() []byte { + key := os.Getenv("ENCRYPTION_KEY") + return []byte(key) +} diff --git a/docker-compose.yml b/docker-compose.yml index 19f6682..265734e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -30,8 +30,8 @@ services: redis: image: redis:latest ports: - - "6379:6379" + - "6379:6379" volumes: postgres_data: - driver: local + driver: local \ No newline at end of file diff --git a/internal/http/auth/auth.go b/internal/http/auth/auth.go index f507d37..5893f50 100644 --- a/internal/http/auth/auth.go +++ b/internal/http/auth/auth.go @@ -28,6 +28,8 @@ type AuthService interface { Register(ctx context.Context, user models.User) (models.User, error) CreateSession(ctx context.Context, ID int) (models.Session, error) DeleteSession(ctx context.Context, token string) + CreateCSRF(ctx context.Context, encryptionKey []byte, s *models.Session) (string, error) + CheckCSRF(ctx context.Context, encryptionKey []byte, s *models.Session, inputToken string) (bool, error) } type RegisterRequest struct { @@ -291,3 +293,10 @@ func userToProfileResponse(user models.User) ProfileResponse { func (h *AuthHandler) CheckSessionMiddleware(ctx context.Context, cookie string) (models.Session, bool) { return h.service.CheckSession(ctx, cookie) } +func (h *AuthHandler) CheckCSRFMiddleware(ctx context.Context, encryptionKey []byte, s *models.Session, inputToken string) (bool, error) { + return h.service.CheckCSRF(ctx, encryptionKey, s, inputToken) +} + +func (h *AuthHandler) CreateCSRFMiddleware(ctx context.Context, encryptionKey []byte, s *models.Session) (string, error) { + return h.service.CreateCSRF(ctx, encryptionKey, s) +} diff --git a/internal/http/errors/errors.go b/internal/http/errors/errors.go index ec630bb..04e7531 100644 --- a/internal/http/errors/errors.go +++ b/internal/http/errors/errors.go @@ -70,4 +70,19 @@ var ( Message: "User doesn't own this event", Code: "access_denied", } + + ErrCSRFTokenMissing = &HttpError{ + Message: "CSRF token is missing or invalid", + Code: "csrf_missing", + } + + ErrInvalidCSRFToken = &HttpError{ + Message: "Invalid CSRF token", + Code: "invalid_csrf_token", + } + + ErrCSRFTokenGenerationFailed = &HttpError{ + Message: "Failed to generate CSRF token", + Code: "csrf_token_generation_failed", + } ) diff --git a/internal/http/utils/utils.go b/internal/http/utils/utils.go index 61c98f9..c110442 100644 --- a/internal/http/utils/utils.go +++ b/internal/http/utils/utils.go @@ -22,6 +22,10 @@ type sessionKeyType struct{} var sessionKey sessionKeyType +type csrfKeyType struct{} + +var csrfKey csrfKeyType + type requestIDKeyType struct{} var requestIDKey requestIDKeyType @@ -43,6 +47,18 @@ func SetSessionInContext(ctx context.Context, session models.Session) context.Co return context.WithValue(ctx, sessionKey, session) } +func GetCSRFFromContext(ctx context.Context) (models.TokenData, bool) { + csrfKey, ok := ctx.Value(csrfKey).(models.TokenData) + if !ok { + return csrfKey, false + } + return csrfKey, true +} + +func SetCSRFInContext(ctx context.Context, token models.TokenData) context.Context { + return context.WithValue(ctx, csrfKey, token) +} + func GetRequestIDFromContext(ctx context.Context) string { ID, _ := ctx.Value(requestIDKey).(string) return ID diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 0378f18..527efef 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -1,26 +1,57 @@ package middleware import ( + httpErrors "kudago/internal/http/errors" "net/http" "strings" "kudago/internal/http/auth" "kudago/internal/http/utils" + "kudago/internal/models" ) const ( SessionToken = "session_token" - SessionKey = "session" ) -func AuthMiddleware(whitelist []string, authHandler *auth.AuthHandler, next http.Handler) http.Handler { +func AuthWithCSRFMiddleware(whitelist []string, authHandler *auth.AuthHandler, encryptionKey []byte, next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookie, err := r.Cookie(SessionToken) if err == nil { session, authenticated := authHandler.CheckSessionMiddleware(r.Context(), cookie.Value) if authenticated { ctx := utils.SetSessionInContext(r.Context(), session) r = r.WithContext(ctx) + + if r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions { + csrfToken := r.Header.Get("X-CSRF-Token") + if csrfToken == "" { + utils.WriteResponse(w, http.StatusForbidden, httpErrors.ErrCSRFTokenMissing) + return + } + + valid, err := authHandler.CheckCSRFMiddleware(r.Context(), encryptionKey, &session, csrfToken) + if err != nil || !valid { + utils.WriteResponse(w, http.StatusForbidden, httpErrors.ErrInvalidCSRFToken) + return + } + } else if r.Method == http.MethodGet { + session, ok := utils.GetSessionFromContext(r.Context()) + if ok { + csrfToken, err := authHandler.CreateCSRFMiddleware(r.Context(), encryptionKey, &session) + if err != nil { + utils.WriteResponse(w, http.StatusInternalServerError, httpErrors.ErrCSRFTokenGenerationFailed) + return + } + + w.Header().Set("X-CSRF-Token", csrfToken) + + ctx := utils.SetCSRFInContext(r.Context(), models.TokenData{CSRFtoken: csrfToken}) + r = r.WithContext(ctx) + } + } + next.ServeHTTP(w, r) return } @@ -34,6 +65,5 @@ func AuthMiddleware(whitelist []string, authHandler *auth.AuthHandler, next http } http.Error(w, "Unauthorized", http.StatusUnauthorized) - return }) } diff --git a/internal/models/csrf_token.go b/internal/models/csrf_token.go new file mode 100644 index 0000000..68a777e --- /dev/null +++ b/internal/models/csrf_token.go @@ -0,0 +1,10 @@ +package models + +import "time" + +type TokenData struct { + CSRFtoken string + SessionToken string + UserID int + Exp time.Time +} diff --git a/internal/repository/csrf/csrf.go b/internal/repository/csrf/csrf.go new file mode 100644 index 0000000..3f68538 --- /dev/null +++ b/internal/repository/csrf/csrf.go @@ -0,0 +1,76 @@ +package csrf + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "time" + + "github.com/redis/go-redis/v9" + "kudago/internal/models" +) + +const CSRFTokenExpTime = 15 * time.Minute + +type csrfDB struct { + client *redis.Client +} + +func NewDB(client *redis.Client) *csrfDB { + return &csrfDB{client: client} +} + +func (db *csrfDB) CreateCSRF(ctx context.Context, encryptionKey []byte, s *models.Session) (string, error) { + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return "", err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + + nonce := make([]byte, aesgcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + + tokenExpTime := time.Now().Add(CSRFTokenExpTime) + td := models.TokenData{ + SessionToken: s.Token, + UserID: s.UserID, + Exp: tokenExpTime, + } + data, _ := json.Marshal(td) + ciphertext := aesgcm.Seal(nil, nonce, data, nil) + + res := append(nonce, ciphertext...) + token := base64.StdEncoding.EncodeToString(res) + + err = db.client.Set(ctx, prefixedKey(s.Token), token, CSRFTokenExpTime).Err() + if err != nil { + return "", fmt.Errorf("failed to store token in Redis: %v", err) + } + + return token, nil +} + +func (db *csrfDB) GetCSRF(ctx context.Context, s *models.Session) (string, error) { + storedToken, err := db.client.Get(ctx, prefixedKey(s.Token)).Result() + if err == redis.Nil { + return "", fmt.Errorf("token not found in Redis") + } else if err != nil { + return "", fmt.Errorf("failed to get token from Redis: %v", err) + } + return storedToken, nil +} + +func prefixedKey(key string) string { + return fmt.Sprintf("%s:%s", key, "csrf") +} diff --git a/internal/service/auth/auth.go b/internal/service/auth/auth.go index ea8c8bc..abc6ccc 100644 --- a/internal/service/auth/auth.go +++ b/internal/service/auth/auth.go @@ -2,6 +2,12 @@ package authService import ( "context" + "crypto/aes" + "crypto/cipher" + "encoding/base64" + "encoding/json" + "fmt" + "time" "kudago/internal/models" ) @@ -9,6 +15,7 @@ import ( type authService struct { UserDB UserDB SessionDB SessionDB + CsrfDB CsrfDB } type UserDB interface { @@ -24,8 +31,13 @@ type SessionDB interface { DeleteSession(ctx context.Context, token string) } -func NewService(userDB UserDB, sessionDB SessionDB) authService { - return authService{UserDB: userDB, SessionDB: sessionDB} +type CsrfDB interface { + CreateCSRF(ctx context.Context, encryptionKey []byte, s *models.Session) (string, error) + GetCSRF(ctx context.Context, s *models.Session) (string, error) +} + +func NewService(userDB UserDB, sessionDB SessionDB, csrfDB CsrfDB) authService { + return authService{UserDB: userDB, SessionDB: sessionDB, CsrfDB: csrfDB} } func (a *authService) CheckSession(ctx context.Context, cookie string) (models.Session, bool) { @@ -62,3 +74,55 @@ func (a *authService) CreateSession(ctx context.Context, ID int) (models.Session func (a *authService) DeleteSession(ctx context.Context, token string) { a.SessionDB.DeleteSession(ctx, token) } +func (a *authService) CreateCSRF(ctx context.Context, encryptionKey []byte, s *models.Session) (string, error) { + return a.CsrfDB.CreateCSRF(ctx, encryptionKey, s) +} + +func (a *authService) CheckCSRF(ctx context.Context, encryptionKey []byte, s *models.Session, inputToken string) (bool, error) { + storedToken, err := a.CsrfDB.GetCSRF(ctx, s) + if err != nil { + return false, err + } + + if storedToken != inputToken { + return false, fmt.Errorf("invalid token") + } + + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return false, err + } + + aesgcm, err := cipher.NewGCM(block) + if err != nil { + return false, err + } + + ciphertext, err := base64.StdEncoding.DecodeString(inputToken) + if err != nil { + return false, err + } + + nonceSize := aesgcm.NonceSize() + if len(ciphertext) < nonceSize { + return false, fmt.Errorf("short ciphertext") + } + + nonce, ciphertext := ciphertext[:nonceSize], ciphertext[nonceSize:] + plaintext, err := aesgcm.Open(nil, nonce, ciphertext, nil) + if err != nil { + return false, fmt.Errorf("decrypt fail: %v", err) + } + + td := &models.TokenData{} + err = json.Unmarshal(plaintext, &td) + if err != nil { + return false, fmt.Errorf("bad json: %v", err) + } + + if td.Exp.Unix() < time.Now().Unix() { + return false, fmt.Errorf("token expired") + } + + return s.Token == td.SessionToken && s.UserID == td.UserID, nil +} diff --git a/static/images/50ZHs2a9FW7smiOw_1730731958354.png b/static/images/50ZHs2a9FW7smiOw_1730731958354.png new file mode 100644 index 0000000..3256110 Binary files /dev/null and b/static/images/50ZHs2a9FW7smiOw_1730731958354.png differ diff --git a/static/images/7iXdGW2jezgYjf_O_1730731977372.png b/static/images/7iXdGW2jezgYjf_O_1730731977372.png new file mode 100644 index 0000000..3256110 Binary files /dev/null and b/static/images/7iXdGW2jezgYjf_O_1730731977372.png differ diff --git a/static/images/R-m1BgDB9rPrO1Vk_1730745578724.png b/static/images/R-m1BgDB9rPrO1Vk_1730745578724.png new file mode 100644 index 0000000..3256110 Binary files /dev/null and b/static/images/R-m1BgDB9rPrO1Vk_1730745578724.png differ diff --git a/static/images/vmXI9F6xZz2vhoFG_1730732264682.png b/static/images/vmXI9F6xZz2vhoFG_1730732264682.png new file mode 100644 index 0000000..3256110 Binary files /dev/null and b/static/images/vmXI9F6xZz2vhoFG_1730732264682.png differ diff --git a/static/images/ymGmjzxqZCyU3GVH_1730731901641.png b/static/images/ymGmjzxqZCyU3GVH_1730731901641.png new file mode 100644 index 0000000..3256110 Binary files /dev/null and b/static/images/ymGmjzxqZCyU3GVH_1730731901641.png differ