From 6eea5ba94f1cd5c884cecef6e1545ebfdbbd953f Mon Sep 17 00:00:00 2001 From: Ryan P Date: Fri, 18 Feb 2022 07:33:04 -0500 Subject: [PATCH] Add HTTP port rules to HTTPFilter (#185) --- rcap/http_rulefilter.go | 52 ++++++++++++++++++++++ rcap/http_rulefilter_test.go | 83 ++++++++++++++++++++++++++++++++++++ util/util.go | 10 +++++ util/util_test.go | 12 ++++++ 4 files changed, 157 insertions(+) diff --git a/rcap/http_rulefilter.go b/rcap/http_rulefilter.go index 26aff363..fa8be83b 100644 --- a/rcap/http_rulefilter.go +++ b/rcap/http_rulefilter.go @@ -3,9 +3,13 @@ package rcap import ( "net" "net/http" + "net/url" + "strconv" "strings" "github.com/pkg/errors" + + "github.com/suborbital/reactr/util" ) var ( @@ -13,17 +17,22 @@ var ( ErrIPsDisallowed = errors.New("requests to IP addresses are disallowed") ErrPrivateDisallowed = errors.New("requests to private IP address ranges are disallowed") ErrDomainDisallowed = errors.New("requests to this domain are disallowed") + ErrPortDisallowed = errors.New("requests to this port are disallowed") ) // HTTPRules is a set of rules that governs use of the HTTP capability type HTTPRules struct { AllowedDomains []string `json:"allowedDomains" yaml:"allowedDomains"` BlockedDomains []string `json:"blockedDomains" yaml:"blockedDomains"` + AllowedPorts []int `json:"allowedPorts" yaml:"allowedPorts"` + BlockedPorts []int `json:"blockedPorts" yaml:"blockedPorts"` AllowIPs bool `json:"allowIPs" yaml:"allowIPs"` AllowPrivate bool `json:"allowPrivate" yaml:"allowPrivate"` AllowHTTP bool `json:"allowHTTP" yaml:"allowHTTP"` } +var standardPorts = []int{80, 443} + // requestIsAllowed returns a non-nil error if the provided request is not allowed to proceed func (h HTTPRules) requestIsAllowed(req *http.Request) error { // Hostname removes port numbers as well as IPv6 [ and ] @@ -35,6 +44,11 @@ func (h HTTPRules) requestIsAllowed(req *http.Request) error { } } + // Evaluate port access rules + if err := h.portAllowed(req.URL); err != nil { + return err + } + // determine if the passed-in host is an IP address isRawIP := net.ParseIP(req.URL.Hostname()) != nil if !h.AllowIPs && isRawIP { @@ -90,6 +104,44 @@ func (h HTTPRules) requestIsAllowed(req *http.Request) error { return nil } +// portAllowed evaluates port allowance rules +func (h HTTPRules) portAllowed(url *url.URL) error { + // Backward Compatibility: + // Allow all ports if no allow/block list has been configured + if len(h.AllowedPorts)+len(h.BlockedPorts) == 0 { + return nil + } + + port, err := readPort(url) + if err != nil { + return ErrPortDisallowed + } + + if util.ContainsInt(port, h.BlockedPorts) { + return ErrPortDisallowed + } + + for _, p := range append(standardPorts, h.AllowedPorts...) { + if p == port { + return nil + } + } + + return ErrPortDisallowed +} + +// readPort returns normalized URL port +func readPort(url *url.URL) (int, error) { + if url.Port() == "" { + if url.Scheme == "https" { + return 443, nil + } + return 80, nil + } + + return strconv.Atoi(url.Port()) +} + // returns nil if the host does not resolve to an IP in a private range // returns ErrPrivateDisallowed if it does func resolvesToPrivate(host string) error { diff --git a/rcap/http_rulefilter_test.go b/rcap/http_rulefilter_test.go index 23b1d12f..c99e79ba 100644 --- a/rcap/http_rulefilter_test.go +++ b/rcap/http_rulefilter_test.go @@ -187,6 +187,89 @@ func TestBlockedDomains(t *testing.T) { }) } +func TestAllowedPorts(t *testing.T) { + rules := defaultHTTPRules() + rules.AllowedPorts = []int{8080} + + t.Run("standard http port allowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + if err := rules.requestIsAllowed(req); err != nil { + t.Error("error occurred, should not have:", err) + } + }) + + t.Run("standard https port allowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) + + if err := rules.requestIsAllowed(req); err != nil { + t.Error("error occurred, should not have:", err) + } + }) + + t.Run("port 8080 allowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com:8080", nil) + + if err := rules.requestIsAllowed(req); err != nil { + t.Error("error occurred, should not have:", err) + } + }) + + t.Run("port 8088 disallowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com:8088", nil) + + if err := rules.requestIsAllowed(req); err == nil { + t.Error("error did not occur, should have") + } + }) +} + +func TestBlockedPorts(t *testing.T) { + rules := defaultHTTPRules() + rules.AllowedPorts = []int{8081, 8082} + rules.BlockedPorts = []int{80, 443, 8080, 8081} + + t.Run("standard HTTP port disallowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + if err := rules.requestIsAllowed(req); err == nil { + t.Error("error did not occur, should have") + } + }) + + t.Run("standard HTTPS port disallowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "https://example.com", nil) + + if err := rules.requestIsAllowed(req); err == nil { + t.Error("error did not occur, should have") + } + }) + + t.Run("port 8080 disallowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com", nil) + + if err := rules.requestIsAllowed(req); err == nil { + t.Error("error did not occur, should have") + } + }) + + t.Run("blocked list takes precedence over allow list", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com:8081", nil) + + if err := rules.requestIsAllowed(req); err == nil { + t.Error("error did not occur, should have") + } + }) + + t.Run("port 8082 allowed", func(t *testing.T) { + req, _ := http.NewRequest(http.MethodGet, "http://example.com:8082", nil) + + if err := rules.requestIsAllowed(req); err != nil { + t.Error("error occurred, should not have:", err) + } + }) +} + func TestBlockedWithCNAME(t *testing.T) { rules := defaultHTTPRules() rules.BlockedDomains = []string{"hosting.gitbook.io"} diff --git a/util/util.go b/util/util.go index 5dffc472..890e08b5 100644 --- a/util/util.go +++ b/util/util.go @@ -20,3 +20,13 @@ func GenerateResultID() string { return id } + +// ContainsInt returns true if value present in int slice +func ContainsInt(value int, values []int) bool { + for _, p := range values { + if p == value { + return true + } + } + return false +} diff --git a/util/util_test.go b/util/util_test.go index 065e02c1..2bcc9afe 100644 --- a/util/util_test.go +++ b/util/util_test.go @@ -11,3 +11,15 @@ func TestGenerateResultID(t *testing.T) { t.Errorf("id has length %d, expected 24", len(id)) } } + +func TestContainsInt(t *testing.T) { + container := []int{1, 2, 3, 4} + + if !ContainsInt(3, container) { + t.Errorf("expected value not found in container") + } + + if ContainsInt(5, container) { + t.Errorf("should not have found value in container") + } +}