diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 7eb76f91..c0608325 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "errors" "fmt" "net" @@ -100,22 +101,88 @@ func getClientIPFromRequest(proxyCount int, r *http.Request) string { } } - remoteIP, _, err := net.SplitHostPort(readUserIP(r)) + remoteIP, _, err := net.SplitHostPort(getIPAdress(r)) if err != nil { remoteIP = r.RemoteAddr } return remoteIP } -func readUserIP(r *http.Request) string { - IPAddress := r.Header.Get("X-Real-Ip") - if IPAddress == "" { - IPAddress = r.Header.Get("X-Forwarded-For") +//ipRange - a structure that holds the start and end of a range of ip addresses +type ipRange struct { + start net.IP + end net.IP +} + +// inRange - check to see if a given ip address is within a range given +func inRange(r ipRange, ipAddress net.IP) bool { + // strcmp type byte comparison + if bytes.Compare(ipAddress, r.start) >= 0 && bytes.Compare(ipAddress, r.end) < 0 { + return true + } + return false +} + +var privateRanges = []ipRange{ + ipRange{ + start: net.ParseIP("10.0.0.0"), + end: net.ParseIP("10.255.255.255"), + }, + ipRange{ + start: net.ParseIP("100.64.0.0"), + end: net.ParseIP("100.127.255.255"), + }, + ipRange{ + start: net.ParseIP("172.16.0.0"), + end: net.ParseIP("172.31.255.255"), + }, + ipRange{ + start: net.ParseIP("192.0.0.0"), + end: net.ParseIP("192.0.0.255"), + }, + ipRange{ + start: net.ParseIP("192.168.0.0"), + end: net.ParseIP("192.168.255.255"), + }, + ipRange{ + start: net.ParseIP("198.18.0.0"), + end: net.ParseIP("198.19.255.255"), + }, +} + + +// isPrivateSubnet - check to see if this ip is in a private subnet +func isPrivateSubnet(ipAddress net.IP) bool { + // my use case is only concerned with ipv4 atm + if ipCheck := ipAddress.To4(); ipCheck != nil { + // iterate over all our ranges + for _, r := range privateRanges { + // check if this ip is in a private range + if inRange(r, ipAddress){ + return true + } + } } - if IPAddress == "" { - IPAddress = r.RemoteAddr + return false +} + +func getIPAdress(r *http.Request) string { + for _, h := range []string{"X-Forwarded-For", "X-Real-Ip"} { + addresses := strings.Split(r.Header.Get(h), ",") + // march from right to left until we get a public address + // that will be the address right before our proxy. + for i := len(addresses) -1 ; i >= 0; i-- { + ip := strings.TrimSpace(addresses[i]) + // header can contain spaces too, strip those out. + realIP := net.ParseIP(ip) + if !realIP.IsGlobalUnicast() || isPrivateSubnet(realIP) { + // bad address, go to next + continue + } + return ip + } } - return IPAddress + return "" } type Captcha struct {