Skip to content

Commit

Permalink
Merge pull request #27 from ulule/forward-header
Browse files Browse the repository at this point in the history
IP Forward header from reverse proxy
  • Loading branch information
novln authored Nov 10, 2017
2 parents eaf9865 + a153298 commit 0d25c13
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 28 deletions.
2 changes: 1 addition & 1 deletion drivers/middleware/gin/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc

// Handle gin request.
func (middleware *Middleware) Handle(c *gin.Context) {
context, err := middleware.Limiter.Get(c, limiter.GetIPKey(c.Request))
context, err := middleware.Limiter.Get(c, c.ClientIP())
if err != nil {
middleware.OnError(c, err)
c.Abort()
Expand Down
9 changes: 5 additions & 4 deletions drivers/middleware/stdlib/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ import (

// Middleware is the middleware for basic http.Handler.
type Middleware struct {
Limiter *limiter.Limiter
OnError ErrorHandler
OnLimitReached LimitReachedHandler
Limiter *limiter.Limiter
OnError ErrorHandler
OnLimitReached LimitReachedHandler
TrustForwardHeader bool
}

// NewMiddleware return a new instance of a basic HTTP middleware.
Expand All @@ -32,7 +33,7 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) *Middleware {
// Handler the middleware handler.
func (middleware *Middleware) Handler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
context, err := middleware.Limiter.Get(r.Context(), limiter.GetIPKey(r))
context, err := middleware.Limiter.Get(r.Context(), limiter.GetIPKey(r, middleware.TrustForwardHeader))
if err != nil {
middleware.OnError(w, r, err)
return
Expand Down
7 changes: 7 additions & 0 deletions drivers/middleware/stdlib/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,10 @@ func WithLimitReachedHandler(handler LimitReachedHandler) Option {
func DefaultLimitReachedHandler(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Limit exceeded", http.StatusTooManyRequests)
}

// WithForwardHeader will configure the Middleware to trust X-Real-IP and X-Forwarded-For headers.
func WithForwardHeader(trusted bool) Option {
return option(func(middleware *Middleware) {
middleware.TrustForwardHeader = trusted
})
}
1 change: 1 addition & 0 deletions examples/gin/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func main() {

// Launch a simple server.
router := gin.Default()
router.ForwardedByClientIP = true
router.Use(middleware)
router.GET("/", index)
log.Fatal(router.Run(":7777"))
Expand Down
2 changes: 1 addition & 1 deletion examples/http/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func main() {
}

// Create a new middleware with the limiter instance.
middleware := stdlib.NewMiddleware(limiter.New(store, rate))
middleware := stdlib.NewMiddleware(limiter.New(store, rate), stdlib.WithForwardHeader(true))

// Launch a simple server.
http.Handle("/", middleware.Handler(http.HandlerFunc(index)))
Expand Down
31 changes: 17 additions & 14 deletions utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,33 @@ import (
)

// GetIP returns IP address from request.
func GetIP(r *http.Request) net.IP {
ip := r.Header.Get("X-Forwarded-For")
if ip != "" {
parts := strings.Split(ip, ",")
part := strings.TrimSpace(parts[0])
return net.ParseIP(part)
}
func GetIP(r *http.Request, trustForwardHeader ...bool) net.IP {
if len(trustForwardHeader) >= 1 && trustForwardHeader[0] {
ip := r.Header.Get("X-Forwarded-For")
if ip != "" {
parts := strings.SplitN(ip, ",", 2)
part := strings.TrimSpace(parts[0])
return net.ParseIP(part)
}

ip = r.Header.Get("X-Real-IP")
if ip != "" {
return net.ParseIP(ip)
ip = strings.TrimSpace(r.Header.Get("X-Real-IP"))
if ip != "" {
return net.ParseIP(ip)
}
}

host, _, err := net.SplitHostPort(r.RemoteAddr)
remoteAddr := strings.TrimSpace(r.RemoteAddr)
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
return net.ParseIP(r.RemoteAddr)
return net.ParseIP(remoteAddr)
}

return net.ParseIP(host)
}

// GetIPKey extracts IP from request and returns hashed IP to use as store key.
func GetIPKey(r *http.Request) string {
return GetIP(r).String()
func GetIPKey(r *http.Request, trustForwardHeader ...bool) string {
return GetIP(r, trustForwardHeader...).String()
}

// Random return a random integer between min and max.
Expand Down
72 changes: 64 additions & 8 deletions utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,34 +37,62 @@ func TestGetIP(t *testing.T) {

scenarios := []struct {
request *http.Request
hasProxy bool
expected net.IP
}{
{
//
// Scenario #1 : RemoteAddr
// Scenario #1 : RemoteAddr without proxy.
//
request: request1,
hasProxy: false,
expected: net.ParseIP("8.8.8.8"),
},
{
//
// Scenario #2 : X-Forwarded-For
// Scenario #2 : X-Forwarded-For without proxy.
//
request: request2,
hasProxy: false,
expected: net.ParseIP("8.8.8.8"),
},
{
//
// Scenario #3 : X-Real-IP without proxy.
//
request: request3,
hasProxy: false,
expected: net.ParseIP("8.8.8.8"),
},
{
//
// Scenario #4 : RemoteAddr with proxy.
//
request: request1,
hasProxy: true,
expected: net.ParseIP("8.8.8.8"),
},
{
//
// Scenario #5 : X-Forwarded-For with proxy.
//
request: request2,
hasProxy: true,
expected: net.ParseIP("9.9.9.9"),
},
{
//
// Scenario #3 : X-Real-IP
// Scenario #6 : X-Real-IP with proxy.
//
request: request3,
hasProxy: true,
expected: net.ParseIP("6.6.6.6"),
},
}

for i, scenario := range scenarios {
message := fmt.Sprintf("Scenario #%d", (i + 1))
ip := limiter.GetIP(scenario.request)
ip := limiter.GetIP(scenario.request, scenario.hasProxy)
is.Equal(scenario.expected, ip, message)
}
}
Expand Down Expand Up @@ -94,34 +122,62 @@ func TestGetIPKey(t *testing.T) {

scenarios := []struct {
request *http.Request
hasProxy bool
expected string
}{
{
//
// Scenario #1 : RemoteAddr
// Scenario #1 : RemoteAddr without proxy.
//
request: request1,
hasProxy: false,
expected: "8.8.8.8",
},
{
//
// Scenario #2 : X-Forwarded-For without proxy.
//
request: request2,
hasProxy: false,
expected: "8.8.8.8",
},
{
//
// Scenario #3 : X-Real-IP without proxy.
//
request: request3,
hasProxy: false,
expected: "8.8.8.8",
},
{
//
// Scenario #4 : RemoteAddr without proxy.
//
request: request1,
hasProxy: true,
expected: "8.8.8.8",
},
{
//
// Scenario #2 : X-Forwarded-For
// Scenario #5 : X-Forwarded-For without proxy.
//
request: request2,
hasProxy: true,
expected: "9.9.9.9",
},
{
//
// Scenario #3 : X-Real-IP
// Scenario #6 : X-Real-IP without proxy.
//
request: request3,
hasProxy: true,
expected: "6.6.6.6",
},
}

for i, scenario := range scenarios {
message := fmt.Sprintf("Scenario #%d", (i + 1))
key := limiter.GetIPKey(scenario.request)
key := limiter.GetIPKey(scenario.request, scenario.hasProxy)
is.Equal(scenario.expected, key, message)
}
}

0 comments on commit 0d25c13

Please sign in to comment.