Skip to content

Commit

Permalink
Merge pull request #4 from dkijkuit/feature/code-refactoring
Browse files Browse the repository at this point in the history
Refactored code for maintenance and updated docs
  • Loading branch information
dkijkuit authored Dec 10, 2020
2 parents e9ccb0c + cc2d08b commit 943ed25
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 38 deletions.
16 changes: 8 additions & 8 deletions README.MD
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Traefik 2 check request headers middleware plugin

This plugin checks the incoming request for specific headers and their values to be present and matching the configuration. If the request does not validate against the configured headers, the middleware will return a 403 Forbidden status code.
This plugin checks the incoming request for specific headers and their values to be present and matching the configuration. If the request does not validate against the configured headers, the middleware will return a 403 Forbidden status code. This is can also be used to check client certificate information in combination with the [PassTLSClientCert](https://doc.traefik.io/traefik/middlewares/passtlsclientcert/) Traefik middleware, details can be found at the end of this document.

## Dev `traefik.yml` configuration file for traefik

Expand Down Expand Up @@ -98,14 +98,14 @@ Should return a 200 showing details about the request.

Supported configurations per header

| Setting | Allowed values | Description |
|---|---|---|
| name | string | Name of the request header |
| Setting | Allowed values | Description |
| :-- | :-- | :-- |
| name | string | Name of the request header |
| matchtype | one, all | Match on all values or one of the values specified. The value 'all' is only allowed in combination with the 'contains' setting.|
| values | []string | A list of allowed values which are matched against the request header value|
| contains | boolean | If set to true (default false), the request is allowed if the rtequest header value contains the value specified in the configuration |
| required | boolean | If set to false (default true), the request is allowed if the header is absent or the value is empty|
| urldecode | boolean | If set to true (default false), the value of the request header will be URL decoded before further processing with the plugin. This is useful when using this plugin with the [PassTLSClientCert](https://doc.traefik.io/traefik/middlewares/passtlsclientcert/) middleware that Traefik offers.
| values | []string | A list of allowed values which are matched against the request header value|
| contains | boolean | If set to true (default false), the request is allowed if the rtequest header value contains the value specified in the configuration |
| required | boolean | If set to false (default true), the request is allowed if the header is absent or the value is empty|
| urldecode | boolean | If set to true (default false), the value of the request header will be URL decoded before further processing with the plugin. This is useful when using this plugin with the [PassTLSClientCert](https://doc.traefik.io/traefik/middlewares/passtlsclientcert/) middleware that Traefik offers.

#
## Example 1 config
Expand Down
73 changes: 44 additions & 29 deletions header_match.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ type HeaderMatch struct {
type MatchType string

const (
//MatchAll requires all values to be matched
MatchAll MatchType = "all"
//MatchOne requires only one value to be matched
MatchOne MatchType = "one"
)

Expand Down Expand Up @@ -84,44 +86,21 @@ func (a *HeaderMatch) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
headersValid := true

for _, vHeader := range a.headers {

reqHeaderVal := req.Header.Get(vHeader.Name)

if vHeader.IsURLDecode() {
reqHeaderVal, _ = url.QueryUnescape(reqHeaderVal)
}

if vHeader.IsContains() && reqHeaderVal != "" {
matchCount := 0
for _, value := range vHeader.Values {
if strings.Contains(reqHeaderVal, value) {
matchCount++
}
}

if vHeader.MatchType == string(MatchOne) && matchCount == 0 {
headersValid = false
break
}
if vHeader.MatchType == string(MatchAll) && matchCount != len(vHeader.Values) {
headersValid = false
break
}
headersValid = checkContains(&reqHeaderVal, &vHeader)
} else {
matchCount := 0
for _, value := range vHeader.Values {
if reqHeaderVal == value {
matchCount++
}

if !vHeader.IsRequired() && reqHeaderVal == "" {
matchCount++
}
}
headersValid = checkRequired(&reqHeaderVal, &vHeader)
}

if matchCount == 0 {
headersValid = false
break
}
if !headersValid {
break
}
}

Expand All @@ -132,6 +111,42 @@ func (a *HeaderMatch) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}

func checkContains(requestValue *string, vHeader *SingleHeader) bool {
matchCount := 0
for _, value := range vHeader.Values {
if strings.Contains(*requestValue, value) {
matchCount++
}
}

if matchCount == 0 {
return false
} else if vHeader.MatchType == string(MatchAll) && matchCount != len(vHeader.Values) {
return false
}

return true
}

func checkRequired(requestValue *string, vHeader *SingleHeader) bool {
matchCount := 0
for _, value := range vHeader.Values {
if *requestValue == value {
matchCount++
}

if !vHeader.IsRequired() && *requestValue == "" {
matchCount++
}
}

if matchCount == 0 {
return false
}

return true
}

//IsURLDecode checks whether a header value should be url decoded first before testing it
func (s *SingleHeader) IsURLDecode() bool {
if s.URLDecode == nil || *s.URLDecode == false {
Expand Down
13 changes: 12 additions & 1 deletion header_match_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,17 @@ func executeTest(t *testing.T, requestHeaders map[string]string, expectedResultC
Contains: &contains,
URLDecode: &urlDecode,
},
{
Name: "testContainsNotRequired",
Values: []string{
"value_not_important",
"value_not_important_2",
},
MatchType: string(checkheaders.MatchOne),
Required: &not_required,
Contains: &contains,
URLDecode: &urlDecode,
},
}

ctx := context.Background()
Expand All @@ -142,6 +153,6 @@ func executeTest(t *testing.T, requestHeaders map[string]string, expectedResultC
handler.ServeHTTP(recorder, req)

if recorder.Result().StatusCode != expectedResultCode {
t.Errorf("Unexpected response status code: %d, expected: %d", recorder.Result().StatusCode, expectedResultCode)
t.Errorf("Unexpected response status code: %d, expected: %d for incoming request headers: %s", recorder.Result().StatusCode, expectedResultCode, requestHeaders)
}
}

0 comments on commit 943ed25

Please sign in to comment.