diff --git a/.gitignore b/.gitignore index 6ab87a2..a8d4f8f 100644 --- a/.gitignore +++ b/.gitignore @@ -41,3 +41,7 @@ node_modules/ # Windows specific files Thumbs.db + +# Docker files +Dockerfile +docker-compose.yml diff --git a/internal/core/middleware/CSRF.go b/internal/core/middleware/CSRF.go index 9f42b6b..2cccb04 100644 --- a/internal/core/middleware/CSRF.go +++ b/internal/core/middleware/CSRF.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "encoding/base64" "io" + "log" "net/http" ) @@ -16,13 +17,17 @@ func NewCSRFProtection() *CSRFProtection { func (csrf *CSRFProtection) Handle(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodGet { - // Generate and set CSRF token for GET requests - token, err := GenerateCSRFToken() + // Retrieve or set CSRF token for GET requests + _, err := getCSRFCookie(r) if err != nil { - http.Error(w, "Failed to generate CSRF token", http.StatusInternalServerError) - return + // Generate and set a new CSRF token if not present + token, err := GenerateCSRFToken() + if err != nil { + http.Error(w, "Failed to generate CSRF token", http.StatusInternalServerError) + return + } + SetCSRFCookie(w, token) } - SetCSRFCookie(w, token) } else if r.Method == http.MethodPost || r.Method == http.MethodPut || r.Method == http.MethodDelete { // Validate CSRF token for state-changing requests if !ValidateCSRFToken(r) { @@ -34,6 +39,7 @@ func (csrf *CSRFProtection) Handle(next http.Handler) http.Handler { }) } +// GenerateCSRFToken generates a new CSRF token. func GenerateCSRFToken() (string, error) { token := make([]byte, 32) // 32 bytes = 256 bits if _, err := io.ReadFull(rand.Reader, token); err != nil { @@ -53,11 +59,22 @@ func SetCSRFCookie(w http.ResponseWriter, token string) { }) } +// getCSRFCookie retrieves the CSRF token from the cookie, if present. +func getCSRFCookie(r *http.Request) (string, error) { + cookie, err := r.Cookie("csrf_token") + if err != nil { + return "", err + } + return cookie.Value, nil +} + +// ValidateCSRFToken validates the CSRF token from the request header or form data. func ValidateCSRFToken(r *http.Request) bool { cookie, err := r.Cookie("csrf_token") if err != nil { + log.Printf("Error retrieving CSRF cookie: %v", err) return false } - csrfToken := r.Header.Get("X-CSRF-Token") // Or retrieve from form data + csrfToken := r.Header.Get("X-CSRF-Token") // Retrieve from request header return csrfToken == cookie.Value } diff --git a/internal/core/middleware/cacher.go b/internal/core/middleware/cacher.go index 6134a56..cfe5107 100644 --- a/internal/core/middleware/cacher.go +++ b/internal/core/middleware/cacher.go @@ -2,6 +2,7 @@ package middleware import ( "context" + "log" "net/http" "time" @@ -14,9 +15,14 @@ type Caching struct { } func NewCaching(redisAddr string, ttl time.Duration) *Caching { + ctx := context.Background() client := redis.NewClient(&redis.Options{ Addr: redisAddr, // e.g., "localhost:6379" }) + _, err := client.Ping(ctx).Result() + if err != nil { + log.Fatalf("Could not connect to Redis: %v", err) + } return &Caching{ client: client, ttl: ttl, @@ -27,20 +33,31 @@ func (c *Caching) Handle(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.Background() - // Try to get the cached response from Redis - data, err := c.client.Get(ctx, r.RequestURI).Result() - if err == nil { - // If found in cache, write it directly to the response - w.Write([]byte(data)) - return + if r.Method == http.MethodGet { + // Try to get the cached response from Redis + data, err := c.client.Get(ctx, r.RequestURI).Result() + if err == nil { + // If found in cache, write it directly to the response + w.Header().Set("X-Cache-Hit", "true") + w.Write([]byte(data)) + return + } else if err != redis.Nil { + // Log any errors retrieving from Redis + log.Printf("Error retrieving from cache: %v", err) + } } - // If not cached, create a response writer to capture the response + // Create a response writer to capture the response rec := &responseRecorder{ResponseWriter: w, statusCode: http.StatusOK} next.ServeHTTP(rec, r) - // Cache the response in Redis - c.client.Set(ctx, r.RequestURI, rec.body, c.ttl) + if r.Method == http.MethodGet { + // Cache the response in Redis + err := c.client.Set(ctx, r.RequestURI, rec.body, c.ttl).Err() + if err != nil { + log.Printf("Error setting cache: %v", err) + } + } }) } @@ -54,3 +71,10 @@ func (rec *responseRecorder) Write(p []byte) (int, error) { rec.body = append(rec.body, p...) return rec.ResponseWriter.Write(p) } + +// Implement the Flush method +func (rec *responseRecorder) Flush() { + if flusher, ok := rec.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } +} diff --git a/internal/core/middleware/json_parser.go b/internal/core/middleware/json_parser.go index c804f76..4d6a7b8 100644 --- a/internal/core/middleware/json_parser.go +++ b/internal/core/middleware/json_parser.go @@ -44,6 +44,8 @@ func NewJsonParser(options ParserOptions) *JSONParser { } } +type JsonKey string + func (jp *JSONParser) Handle(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.Header.Get("Content-Type") == "application/json" { @@ -70,7 +72,7 @@ func (jp *JSONParser) Handle(next http.Handler) http.Handler { } // Store the parsed JSON in the context - key := "jsonBody" + key := JsonKey("jsonBody") r = r.WithContext(context.WithValue(r.Context(), key, body)) } next.ServeHTTP(w, r)