diff --git a/README.md b/README.md index 5f21dc1..2ab22b4 100644 --- a/README.md +++ b/README.md @@ -38,24 +38,25 @@ docker run --rm -ti \ ### Configuration When running the Proxy, the following flags can be used (none are required) : - -| Flag (or short form) | Type | Description | Default | -|-------------------------------|----------|----------------------------------------------------------|---------| -| `verbose` or `v` | Boolean | Enable additional logging, implies all the log-* options | `False` | -| `log-failed-requests` | Boolean | Log 4xx and 5xx response body | `False` | -| `log-signing-process` | Boolean | Log sigv4 signing process | `False` | -| `unsigned-payload` | Boolean | Prevent signing of the payload" | `False` | -| `port` | String | Port to serve http on | `8080` | -| `strip` or `s` | String | Headers to strip from incoming request | None | -| `duplicate-headers` | String | Duplicate headers to an X-Original- prefix name | None | -| `role-arn` | String | Amazon Resource Name (ARN) of the role to assume | None | -| `name` | String | AWS Service to sign for | None | -| `sign-host` | String | Host to sign for | None | -| `host` | String | Host to proxy to | None | -| `region` | String | AWS region to sign for | None | -| `upstream-url-scheme` | String | Protocol to proxy with | https | -| `no-verify-ssl` | Boolean | Disable peer SSL certificate validation | `False` | -| `transport.idle-conn-timeout` | Duration | Idle timeout to the upstream service | `40s` | +s", " +| Flag (or short form) | Type | Description | Default | +|-------------------------------|----------|------------------------------------------------------------|---------| +| `verbose` or `v` | Boolean | Enable additional logging, implies all the log-* options | `False` | +| `log-failed-requests` | Boolean | Log 4xx and 5xx response body | `False` | +| `log-signing-process` | Boolean | Log sigv4 signing process | `False` | +| `unsigned-payload` | Boolean | Prevent signing of the payload" | `False` | +| `port` | String | Port to serve http on | `8080` | +| `strip` or `s` | String | Headers to strip from incoming request | None | +| `custom-headers` | String | Comma-separated list of custom headers in key=value format | None | +| `duplicate-headers` | String | Duplicate headers to an X-Original- prefix name | None | +| `role-arn` | String | Amazon Resource Name (ARN) of the role to assume | None | +| `name` | String | AWS Service to sign for | None | +| `sign-host` | String | Host to sign for | None | +| `host` | String | Host to proxy to | None | +| `region` | String | AWS region to sign for | None | +| `upstream-url-scheme` | String | Protocol to proxy with | https | +| `no-verify-ssl` | Boolean | Disable peer SSL certificate validation | `False` | +| `transport.idle-conn-timeout` | Duration | Idle timeout to the upstream service | `40s` | ## Examples diff --git a/cmd/aws-sigv4-proxy/main.go b/cmd/aws-sigv4-proxy/main.go index c4817cd..1ea2ab2 100644 --- a/cmd/aws-sigv4-proxy/main.go +++ b/cmd/aws-sigv4-proxy/main.go @@ -19,7 +19,9 @@ import ( "crypto/tls" "net/http" "os" + "reflect" "strconv" + "strings" "time" "aws-sigv4-proxy/handler" @@ -40,6 +42,7 @@ var ( logSinging = kingpin.Flag("log-signing-process", "Log sigv4 signing process").Bool() port = kingpin.Flag("port", "Port to serve http on").Default(":8080").String() strip = kingpin.Flag("strip", "Headers to strip from incoming request").Short('s').Strings() + customHeaders = kingpin.Flag("custom-headers", "Comma-separated list of custom headers in key=value format").String() duplicateHeaders = kingpin.Flag("duplicate-headers", "Duplicate headers to an X-Original- prefix name").Strings() roleArn = kingpin.Flag("role-arn", "Amazon Resource Name (ARN) of the role to assume").String() signingNameOverride = kingpin.Flag("name", "AWS Service to sign for").String() @@ -68,6 +71,27 @@ func main() { log.SetLevel(log.DebugLevel) } + // Initialize an http.Header object for custom headers + customHeadersParsed := make(http.Header) + + // Parse and add custom headers if provided + if *customHeaders != "" { + // Split the headers into key-value pairs + headers := strings.Split(*customHeaders, ",") + + for _, h := range headers { + // Split each header into key and value + kv := strings.SplitN(h, "=", 2) + if len(kv) != 2 { + log.Warnf("Invalid header format: [%s], skipping", h) + continue + } + + // Add the header to the custom headers + customHeadersParsed.Add(kv[0], kv[1]) + } + } + sessionConfig := aws.Config{} if v := os.Getenv("AWS_STS_REGIONAL_ENDPOINTS"); len(v) == 0 { sessionConfig.STSRegionalEndpoint = endpoints.RegionalSTSEndpoint @@ -119,6 +143,7 @@ func main() { }, } + log.WithFields(log.Fields{"CcustomHeadersParsed": reflect.ValueOf(customHeadersParsed).MapKeys()}).Infof("Custom headers, values are redacted: %s", reflect.ValueOf(customHeadersParsed).MapKeys()) log.WithFields(log.Fields{"StripHeaders": *strip}).Infof("Stripping headers %s", *strip) log.WithFields(log.Fields{"DuplicateHeaders": *duplicateHeaders}).Infof("Duplicating headers %s", *duplicateHeaders) log.WithFields(log.Fields{"port": *port}).Infof("Listening on %s", *port) @@ -129,6 +154,7 @@ func main() { Signer: signer, Client: client, StripRequestHeaders: *strip, + CustomHeaders: customHeadersParsed, DuplicateRequestHeaders: *duplicateHeaders, SigningNameOverride: *signingNameOverride, SigningHostOverride: *signingHostOverride, diff --git a/handler/proxy_client.go b/handler/proxy_client.go index 492edf2..e591260 100644 --- a/handler/proxy_client.go +++ b/handler/proxy_client.go @@ -39,6 +39,7 @@ type ProxyClient struct { Signer *v4.Signer Client Client StripRequestHeaders []string + CustomHeaders http.Header DuplicateRequestHeaders []string SigningNameOverride string SigningHostOverride string @@ -226,6 +227,9 @@ func (p *ProxyClient) Do(req *http.Request) (*http.Response, error) { // Add origin headers after request is signed (no overwrite) copyHeaderWithoutOverwrite(proxyReq.Header, req.Header) + // Add custom headers (no overwrite) + copyHeaderWithoutOverwrite(proxyReq.Header, p.CustomHeaders) + if log.GetLevel() == log.DebugLevel { proxyReqDump, err := httputil.DumpRequest(proxyReq, true) if err != nil { diff --git a/handler/proxy_client_test.go b/handler/proxy_client_test.go index edaf1e6..a1f0d0f 100644 --- a/handler/proxy_client_test.go +++ b/handler/proxy_client_test.go @@ -419,6 +419,66 @@ func TestProxyClient_Do(t *testing.T) { }, }, }, + { + name: "should add the custom header", + request: &http.Request{ + Method: "GET", + URL: &url.URL{}, + Host: "execute-api.us-west-2.amazonaws.com", + Body: nil, + Header: http.Header{ + "User-Agent": []string{"customAgent"}, + }, + }, + proxyClient: &ProxyClient{ + Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})), + Client: &mockHTTPClient{}, + DuplicateRequestHeaders: []string{"NonExistentHeader"}, + CustomHeaders: http.Header{"Custom-Header": []string{"customValue"}}, + }, + want: &want{ + resp: &http.Response{}, + err: nil, + request: &http.Request{ + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + "User-Agent": []string{"customAgent"}, + //Ensure the custom header is present + "Custom-Header": []string{"customValue"}, + }, + }, + }, + }, + { + name: "should not overwrite origin header with a custom header", + request: &http.Request{ + Method: "GET", + URL: &url.URL{}, + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + "Custom-Header": []string{"customValue"}, + "User-Agent": []string{"customAgent"}, + }, + Body: nil, + }, + proxyClient: &ProxyClient{ + Signer: v4.NewSigner(credentials.NewCredentials(&mockProvider{})), + Client: &mockHTTPClient{}, + CustomHeaders: http.Header{"Custom-Header": []string{"customValueCustom"}}, + }, + want: &want{ + resp: &http.Response{}, + err: nil, + request: &http.Request{ + Host: "execute-api.us-west-2.amazonaws.com", + Header: http.Header{ + //Ensure the custom header doesn't overwrite an existing header + "Custom-Header": []string{"customValue"}, + "User-Agent": []string{"customAgent"}, + }, + }, + }, + }, } for _, tt := range tests {