diff --git a/.travis.yml b/.travis.yml index e54eb1a..ad86a72 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,4 +1,4 @@ language: go go: - - "1.11.x" + - "1.18.x" - "1.x" diff --git a/CHANGELOG.md b/CHANGELOG.md index d7dd1e5..108dc95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,27 @@ # Changelog -## [0.12.1] - 2019-09-26 +## [1.0.0] - 2023-03-02 +### Fixed +- API Gateway V2: Fixed response header support. +- API Gateway V2: Fixed handling request cookies. +- API Gateway V2: Fixed multi-value query parameters. +- ALB: Fixed double escaping of query parameters. + +### Changed +- `RequestTypeAPIGateway` renamed to `RequestTypeAPIGatewayV1`. +- `ProxyRequestFromContext` renamed to `APIGatewayV1RequestFromContext`. +- `APIGatewayV2HTTPRequestFromContext` renamed to `APIGatewayV2RequestFromContext`. +- `TargetGroupRequestFromContext` renamed to `ALBRequestFromContext`. +- Improved unit tests. +- Go 1.18 is the minimum supported version now. + +## [0.13.0] - 2022-01-08 ### Added -- Fixed compatibility with Go versions older than 1.13. +- API Gateway V2 support (@a-h). + +## [0.12.1] - 2019-09-26 +### Fixed +- Compatibility with Go versions older than 1.13. ## [0.12.0] - 2019-09-26 ### Added @@ -16,7 +35,7 @@ ### Changed - Set RequestURI on request (@RossHammer). - Unescape Path (@RossHammer). -- Multi-value header support implemented using APIGatewayProxyResponse.MultiValueHeaders. +- Multi-value header support implemented using `APIGatewayProxyResponse.MultiValueHeaders`. ## [0.9] - 2018-12-10 ### Added diff --git a/README.md b/README.md index db5dc9d..7d345d5 100644 --- a/README.md +++ b/README.md @@ -15,10 +15,6 @@ import ( "github.com/akrylysov/algnhsa" ) -func indexHandler(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("index")) -} - func addHandler(w http.ResponseWriter, r *http.Request) { f, _ := strconv.Atoi(r.FormValue("first")) s, _ := strconv.Atoi(r.FormValue("second")) @@ -27,14 +23,13 @@ func addHandler(w http.ResponseWriter, r *http.Request) { } func contextHandler(w http.ResponseWriter, r *http.Request) { - proxyReq, ok := algnhsa.ProxyRequestFromContext(r.Context()) + lambdaEvent, ok := algnhsa.APIGatewayV2RequestFromContext(r.Context()) if ok { - fmt.Fprint(w, proxyReq.RequestContext.AccountID) + fmt.Fprint(w, lambdaEvent.RequestContext.AccountID) } } func main() { - http.HandleFunc("/", indexHandler) http.HandleFunc("/add", addHandler) http.HandleFunc("/context", contextHandler) algnhsa.ListenAndServe(http.DefaultServeMux, nil) @@ -56,26 +51,48 @@ import ( func main() { r := chi.NewRouter() r.Get("/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("index")) + w.Write([]byte("hi")) }) algnhsa.ListenAndServe(r, nil) } ``` -## Setting up API Gateway +## Deployment + +First, build your Go application for Linux and zip it: + +```bash +GOOS=linux GOARCH=amd64 go build -o handler +zip handler.zip handler +``` + +AWS provides plenty of ways to expose a Lambda function to the internet. + +### Lambda Function URL + +This is the easier way to deploy your Lambda function as an HTTP endpoint. +It only requires going to the "Function URL" section of the Lambda function configuration and clicking "Configure Function URL". + +### API Gateway + +#### HTTP API + +1. Create a new HTTP API. + +2. Configure a catch-all `$default` route. + +#### REST API 1. Create a new REST API. 2. In the "Resources" section create a new `ANY` method to handle requests to `/` (check "Use Lambda Proxy Integration"). - ![API Gateway index](https://akrylysov.github.io/algnhsa/apigateway-index.png) - 3. Add a catch-all `{proxy+}` resource to handle requests to every other path (check "Configure as proxy resource"). - ![API Gateway catch-all](https://akrylysov.github.io/algnhsa/apigateway-catchall.png) - -## Setting up ALB +### ALB 1. Create a new ALB and point it to your Lambda function. -2. In the target group settings enable "Multi value headers". +2. In the target group settings in the "Attributes" section enable "Multi value headers". + +[AWS Documentation](https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html) diff --git a/adapter.go b/adapter.go index ef0e4e5..f228d7e 100644 --- a/adapter.go +++ b/adapter.go @@ -3,6 +3,7 @@ package algnhsa import ( "context" "encoding/json" + "fmt" "net/http" "net/http/httptest" @@ -21,10 +22,16 @@ func (handler lambdaHandler) Invoke(ctx context.Context, payload []byte) ([]byte if err != nil { return nil, err } + if handler.opts.DebugLog { + fmt.Printf("Response: %+v", resp) + } return json.Marshal(resp) } func (handler lambdaHandler) handleEvent(ctx context.Context, payload []byte) (lambdaResponse, error) { + if handler.opts.DebugLog { + fmt.Printf("Request: %s", payload) + } eventReq, err := newLambdaRequest(ctx, payload, handler.opts) if err != nil { return lambdaResponse{}, err @@ -35,7 +42,7 @@ func (handler lambdaHandler) handleEvent(ctx context.Context, payload []byte) (l } w := httptest.NewRecorder() handler.httpHandler.ServeHTTP(w, r) - return newLambdaResponse(w, handler.opts.binaryContentTypeMap) + return newLambdaResponse(w, handler.opts.binaryContentTypeMap, eventReq.requestType) } // ListenAndServe starts the AWS Lambda runtime (aws-lambda-go lambda.Start) with a given handler. diff --git a/adapter_test.go b/adapter_test.go deleted file mode 100644 index c85dc20..0000000 --- a/adapter_test.go +++ /dev/null @@ -1,462 +0,0 @@ -package algnhsa - -import ( - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "reflect" - "testing" - - "github.com/aws/aws-lambda-go/events" - "github.com/stretchr/testify/assert" -) - -type adapterTestCase struct { - req lambdaRequest - opts *Options - resp lambdaResponse - apigwReq events.APIGatewayProxyRequest - albReq events.ALBTargetGroupRequest - expectedErr error -} - -var commonAdapterTestCases = []adapterTestCase{ - { - req: lambdaRequest{ - Path: "/html", - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "foo", - MultiValueHeaders: map[string][]string{"Content-Type": {"text/html; charset=utf-8"}}, - }, - }, - { - req: lambdaRequest{ - Path: "/text", - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "ok", - }, - }, - { - req: lambdaRequest{ - Path: "/query-params", - QueryStringParameters: map[string]string{ - "a": "1", - "b": "", - }, - MultiValueQueryStringParameters: map[string][]string{ - "b": {"2"}, - "c": {"31", "32", "33"}, - }, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "a=[1], b=[2], c=[31 32 33], unknown=[]", - }, - }, - { - req: lambdaRequest{ - Path: "/path/encode%2Ftest%7C", - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "/path/encode/test|", - }, - }, - { - req: lambdaRequest{ - HTTPMethod: "POST", - Path: "/post-body", - Body: "foobar", - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "foobar", - }, - }, - { - req: lambdaRequest{ - HTTPMethod: "POST", - Path: "/post-body", - Body: "Zm9vYmFy", - IsBase64Encoded: true, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "foobar", - }, - }, - { - req: lambdaRequest{ - HTTPMethod: "POST", - Path: "/form", - MultiValueHeaders: map[string][]string{ - "Content-Type": {"application/x-www-form-urlencoded"}, - "Content-Length": {"19"}, - }, - Body: "f=foo&s=bar&xyz=123", - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "foobar", - }, - }, - { - req: lambdaRequest{ - Path: "/status", - }, - resp: lambdaResponse{ - StatusCode: 204, - MultiValueHeaders: map[string][]string{"Content-Type": {"image/gif"}}, - }, - }, - { - req: lambdaRequest{ - Path: "/headers", - Headers: map[string]string{ - "X-a": "1", - "x-b": "2", - }, - MultiValueHeaders: map[string][]string{ - "x-B": {"21", "22"}, - }, - }, - resp: lambdaResponse{ - StatusCode: 200, - MultiValueHeaders: map[string][]string{ - "Content-Type": {"text/plain; charset=utf-8"}, - "X-Bar": {"baz"}, - "X-Y": {"1", "2"}, - }, - Body: "ok", - }, - }, - { - req: lambdaRequest{ - Path: "/text", - }, - opts: &Options{ - BinaryContentTypes: []string{"text/plain; charset=utf-8"}, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "b2s=", - IsBase64Encoded: true, - }, - }, - { - req: lambdaRequest{ - Path: "/text", - }, - opts: &Options{ - BinaryContentTypes: []string{"*/*"}, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "b2s=", - IsBase64Encoded: true, - }, - }, - { - req: lambdaRequest{ - Path: "/text", - }, - opts: &Options{ - BinaryContentTypes: []string{"text/html; charset=utf-8"}, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "ok", - }, - }, - { - req: lambdaRequest{ - Path: "/404", - }, - resp: lambdaResponse{ - StatusCode: 404, - Body: "404 page not found\n", - MultiValueHeaders: map[string][]string{ - "Content-Type": {"text/plain; charset=utf-8"}, - "X-Content-Type-Options": {"nosniff"}, - }, - }, - }, - { - req: lambdaRequest{ - Path: "/hostname", - Headers: map[string]string{ - "Host": "bar", - }, - }, - resp: lambdaResponse{ - StatusCode: 200, - MultiValueHeaders: map[string][]string{ - "Content-Type": {"text/plain; charset=utf-8"}, - }, - Body: "bar", - }, - }, - { - req: lambdaRequest{ - Path: "/requesturi", - QueryStringParameters: map[string]string{ - "foo": "bar", - }, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "/requesturi?foo=bar", - }, - }, -} - -var apigwAdapterTestCases = []adapterTestCase{ - { - req: lambdaRequest{ - Path: "/apigw/text", - }, - apigwReq: events.APIGatewayProxyRequest{ - PathParameters: map[string]string{ - "proxy": "text", - }, - }, - opts: &Options{ - UseProxyPath: true, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "ok", - }, - }, - { - req: lambdaRequest{ - Path: "/apigw/context", - }, - apigwReq: events.APIGatewayProxyRequest{ - RequestContext: events.APIGatewayProxyRequestContext{ - AccountID: "foo", - }, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "ok", - }, - }, - { - req: lambdaRequest{ - Path: "/apigw/wrongtype", - }, - opts: &Options{ - RequestType: RequestTypeALB, - }, - expectedErr: errALBUnexpectedRequest, - }, -} - -var albAdapterTestCases = []adapterTestCase{ - { - req: lambdaRequest{ - Path: "/alb/context", - }, - albReq: events.ALBTargetGroupRequest{ - RequestContext: events.ALBTargetGroupRequestContext{ - ELB: events.ELBContext{ - TargetGroupArn: "foo", - }, - }, - }, - resp: lambdaResponse{ - StatusCode: 200, - Body: "ok", - }, - }, - { - req: lambdaRequest{ - Path: "/alb/wrongtype", - }, - opts: &Options{ - RequestType: RequestTypeAPIGateway, - }, - expectedErr: errAPIGatewayUnexpectedRequest, - }, -} - -func testHandle(t *testing.T, testCases []adapterTestCase, testMode RequestType, requestType RequestType) { - t.Helper() - asrt := assert.New(t) - - r := http.NewServeMux() - - // Common handlers - r.HandleFunc("/html", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("foo")) - }) - r.HandleFunc("/text", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("ok")) - }) - r.HandleFunc("/query-params", func(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - fmt.Fprintf(w, "a=%s, b=%s, c=%s, unknown=%v", r.Form["a"], r.Form["b"], r.Form["c"], r.Form["unknown"]) - }) - r.HandleFunc("/path/", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(r.URL.Path)) - }) - r.HandleFunc("/post-body", func(w http.ResponseWriter, r *http.Request) { - if r.Method == "POST" { - body, err := ioutil.ReadAll(r.Body) - if err != nil { - fmt.Fprintf(w, "%v", err) - } else { - w.Write(body) - } - } - }) - r.HandleFunc("/form", func(w http.ResponseWriter, r *http.Request) { - if r.Method == "POST" { - w.Write([]byte(r.FormValue("f") + r.FormValue("s") + r.FormValue("unknown"))) - } - }) - r.HandleFunc("/status", func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "image/gif") - w.WriteHeader(204) - }) - r.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("X-A") == "1" && reflect.DeepEqual(r.Header["X-B"], []string{"21", "22"}) { - w.Header().Set("X-Bar", "baz") - w.Header().Add("X-y", "1") - w.Header().Add("X-Y", "2") - w.Write([]byte("ok")) - } - }) - r.HandleFunc("/hostname", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(r.Host)) - }) - r.HandleFunc("/requesturi", func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(r.RequestURI)) - }) - - // APIGateway specific handlers - r.HandleFunc("/apigw/context", func(w http.ResponseWriter, r *http.Request) { - expectedProxyReq := events.APIGatewayProxyRequest{ - HTTPMethod: "GET", - Path: "/apigw/context", - RequestContext: events.APIGatewayProxyRequestContext{ - AccountID: "foo", - }, - } - proxyReq, ok := ProxyRequestFromContext(r.Context()) - if ok && reflect.DeepEqual(expectedProxyReq, proxyReq) { - w.Write([]byte("ok")) - } - }) - - // ALB specific handlers - r.HandleFunc("/alb/context", func(w http.ResponseWriter, r *http.Request) { - expectedProxyReq := events.ALBTargetGroupRequest{ - HTTPMethod: "GET", - Path: "/alb/context", - MultiValueHeaders: map[string][]string{ - "X-Test": {"1"}, - }, - RequestContext: events.ALBTargetGroupRequestContext{ - ELB: events.ELBContext{ - TargetGroupArn: "foo", - }, - }, - } - targetReq, ok := TargetGroupRequestFromContext(r.Context()) - if ok && reflect.DeepEqual(expectedProxyReq, targetReq) { - w.Write([]byte("ok")) - } - }) - - for _, testCase := range testCases { - lambdaReq := testCase.req - if lambdaReq.HTTPMethod == "" { - lambdaReq.HTTPMethod = "GET" - } - - expectedResp := testCase.resp - if expectedResp.MultiValueHeaders == nil { - expectedResp.MultiValueHeaders = map[string][]string{"Content-Type": {"text/plain; charset=utf-8"}} - } - - lambdaPayload, err := json.Marshal(lambdaReq) - asrt.NoError(err) - - var payload []byte - if testMode == RequestTypeAPIGateway { - req := testCase.apigwReq - err = json.Unmarshal(lambdaPayload, &req) - asrt.NoError(err) - if req.RequestContext.AccountID == "" { - req.RequestContext.AccountID = "test" - } - payload, err = json.Marshal(req) - asrt.NoError(err) - } else { - req := testCase.albReq - err := json.Unmarshal(lambdaPayload, &req) - asrt.NoError(err) - if req.RequestContext.ELB.TargetGroupArn == "" { - req.RequestContext.ELB.TargetGroupArn = "test" - } - if req.MultiValueHeaders == nil { - req.MultiValueHeaders = map[string][]string{ - "X-Test": {"1"}, - } - } - payload, err = json.Marshal(req) - asrt.NoError(err) - } - - opts := testCase.opts - if opts == nil { - opts = defaultOptions - opts.RequestType = requestType - } - opts.setBinaryContentTypeMap() - handler := lambdaHandler{httpHandler: r, opts: opts} - resp, err := handler.handleEvent(context.Background(), payload) - if testCase.expectedErr == nil { - asrt.NoError(err) - asrt.EqualValues(expectedResp, resp, testCase) - } else { - asrt.Equal(testCase.expectedErr, err) - } - } -} - -func TestHandleAPIGatewayAuto(t *testing.T) { - var testCases []adapterTestCase - testCases = append(testCases, commonAdapterTestCases...) - testCases = append(testCases, apigwAdapterTestCases...) - testHandle(t, testCases, RequestTypeAPIGateway, RequestTypeAuto) -} - -func TestHandleAPIGatewayForced(t *testing.T) { - var testCases []adapterTestCase - testCases = append(testCases, commonAdapterTestCases...) - testCases = append(testCases, apigwAdapterTestCases...) - testHandle(t, testCases, RequestTypeAPIGateway, RequestTypeAPIGateway) -} - -func TestHandleALBAuto(t *testing.T) { - var testCases []adapterTestCase - testCases = append(testCases, commonAdapterTestCases...) - testCases = append(testCases, albAdapterTestCases...) - testHandle(t, testCases, RequestTypeALB, RequestTypeAuto) -} - -func TestHandleALBForced(t *testing.T) { - var testCases []adapterTestCase - testCases = append(testCases, commonAdapterTestCases...) - testCases = append(testCases, albAdapterTestCases...) - testHandle(t, testCases, RequestTypeALB, RequestTypeALB) -} diff --git a/alb.go b/alb.go index eccb967..d71abb7 100644 --- a/alb.go +++ b/alb.go @@ -4,13 +4,21 @@ import ( "context" "encoding/json" "errors" + "net/http" + "net/url" "strings" "github.com/aws/aws-lambda-go/events" ) +/* +AWS Documentation: + +- https://docs.aws.amazon.com/lambda/latest/dg/services-alb.html +*/ + var ( - errALBUnexpectedRequest = errors.New("expected ALBTargetGroupRequest") + errALBUnexpectedRequest = errors.New("expected ALBTargetGroupRequest event") errALBExpectedMultiValueHeaders = errors.New("expected multi value headers; enable Multi value headers in target group settings") ) @@ -36,18 +44,44 @@ func newALBRequest(ctx context.Context, payload []byte, opts *Options) (lambdaRe return lambdaRequest{}, errALBExpectedMultiValueHeaders } + for _, vals := range event.MultiValueQueryStringParameters { + for i, v := range vals { + unescaped, err := url.QueryUnescape(v) + if err != nil { + return lambdaRequest{}, err + } + vals[i] = unescaped + } + } + req := lambdaRequest{ HTTPMethod: event.HTTPMethod, Path: event.Path, - QueryStringParameters: event.QueryStringParameters, MultiValueQueryStringParameters: event.MultiValueQueryStringParameters, - Headers: event.Headers, MultiValueHeaders: event.MultiValueHeaders, Body: event.Body, IsBase64Encoded: event.IsBase64Encoded, SourceIP: getALBSourceIP(event), - Context: newTargetGroupRequestContext(ctx, event), + Context: context.WithValue(ctx, RequestTypeALB, event), + requestType: RequestTypeALB, } return req, nil } + +func newALBResponse(r *http.Response) (lambdaResponse, error) { + resp := lambdaResponse{ + MultiValueHeaders: r.Header, + } + return resp, nil +} + +// ALBRequestFromContext extracts the ALBTargetGroupRequest event from ctx. +func ALBRequestFromContext(ctx context.Context) (events.ALBTargetGroupRequest, bool) { + val := ctx.Value(RequestTypeALB) + if val == nil { + return events.ALBTargetGroupRequest{}, false + } + event, ok := val.(events.ALBTargetGroupRequest) + return event, ok +} diff --git a/alb_test.go b/alb_test.go new file mode 100644 index 0000000..74e9367 --- /dev/null +++ b/alb_test.go @@ -0,0 +1,195 @@ +package algnhsa + +import ( + "context" + "encoding/json" + "errors" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "testing" + + "github.com/aws/aws-lambda-go/events" +) + +var albTestEvent = `{ + "requestContext": { + "elb": { + "targetGroupArn": "arn:aws:elasticloadbalancing:us-east-2:123456789012:targetgroup/lambda-279XGJDqGZ5rsrHC2Fjr/49e9d65c45c6791a" + } + }, + "httpMethod": "GET", + "path": "/lambda", + "multiValueQueryStringParameters": { "myKey": ["val1", "val2"] }, + "multiValueHeaders": { + "accept": ["text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8"], + "accept-encoding": ["gzip"], + "accept-language": ["en-US,en;q=0.9"], + "connection": ["keep-alive"], + "host": ["lambda-alb-123578498.us-east-2.elb.amazonaws.com"], + "upgrade-insecure-requests": ["1"], + "user-agent": ["Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36"], + "x-amzn-trace-id": ["Root=1-5c536348-3d683b8b04734faae651f476"], + "x-forwarded-for": ["72.12.164.125"], + "x-forwarded-port": ["80"], + "x-forwarded-proto": ["http"], + "x-imforwards": ["20"], + "cookie": ["cookie-name=cookie-value;Domain=myweb.com;Secure;HttpOnly","cookie-name=cookie-value;Expires=May 8, 2019"] + }, + "body": "", + "isBase64Encoded": false +} +` + +var expectedALBDump = RequestDebugDump{ + Method: "GET", + URL: struct { + Path string + RawPath string + }{ + Path: "/lambda", + RawPath: "", + }, + RequestURI: "/lambda?myKey=val1&myKey=val2", + Host: "lambda-alb-123578498.us-east-2.elb.amazonaws.com", + RemoteAddr: "72.12.164.125", + Header: map[string][]string{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8"}, + "Accept-Encoding": {"gzip"}, + "Accept-Language": {"en-US,en;q=0.9"}, + "Connection": {"keep-alive"}, + "Host": {"lambda-alb-123578498.us-east-2.elb.amazonaws.com"}, + "Cookie": {"cookie-name=cookie-value;Domain=myweb.com;Secure;HttpOnly", "cookie-name=cookie-value;Expires=May 8, 2019"}, + "Upgrade-Insecure-Requests": {"1"}, + "User-Agent": {"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/71.0.3578.98 Safari/537.36"}, + "X-Amzn-Trace-Id": {"Root=1-5c536348-3d683b8b04734faae651f476"}, + "X-Forwarded-For": {"72.12.164.125"}, + "X-Forwarded-Port": {"80"}, + "X-Forwarded-Proto": {"http"}, + "X-Imforwards": {"20"}, + }, + Form: map[string][]string{ + "myKey": {"val1", "val2"}, + }, + Body: "", +} + +func dumpALB(payload []byte, opts Options) (RequestDebugDump, error) { + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(RequestDebugDumpHandler), + opts: &opts, + } + responseBytes, err := lh.Invoke(context.Background(), payload) + if err != nil { + return RequestDebugDump{}, err + } + var r events.ALBTargetGroupResponse + if err := json.Unmarshal(responseBytes, &r); err != nil { + return RequestDebugDump{}, err + } + if r.StatusCode != 200 { + return RequestDebugDump{}, errors.New("expected status code 200") + } + var dump RequestDebugDump + if err := json.Unmarshal([]byte(r.Body), &dump); err != nil { + return RequestDebugDump{}, err + } + if dump.ALBRequest.HTTPMethod != "GET" { + return RequestDebugDump{}, errors.New("expected method GET") + } + dump.ALBRequest = nil + return dump, nil +} + +func TestALBBase(t *testing.T) { + asrt := assert.New(t) + + dump, err := dumpALB([]byte(albTestEvent), Options{}) + asrt.NoError(err) + + asrt.Equal(expectedALBDump, dump) +} + +func TestALBBase64BodyRequest(t *testing.T) { + asrt := assert.New(t) + + event := events.ALBTargetGroupRequest{} + asrt.NoError(json.Unmarshal([]byte(albTestEvent), &event)) + event.IsBase64Encoded = true + event.Body = "SGVsbG8gZnJvbSBMYW1iZGEh" + encodedEvent, err := json.Marshal(event) + asrt.NoError(err) + + dump, err := dumpALB(encodedEvent, Options{}) + asrt.NoError(err) + expected := expectedALBDump + expected.Body = "Hello from Lambda!" + asrt.Equal(expected, dump) +} + +func TestALBURLEncoding(t *testing.T) { + asrt := assert.New(t) + + event := events.ALBTargetGroupRequest{} + asrt.NoError(json.Unmarshal([]byte(albTestEvent), &event)) + event.Path = "/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82" + event.MultiValueQueryStringParameters["parameter2"] = []string{"тест\""} + encodedEvent, err := json.Marshal(event) + asrt.NoError(err) + + dump, err := dumpALB(encodedEvent, Options{}) + asrt.NoError(err) + expected := expectedALBDump + expected.RequestURI = "/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82?myKey=val1&myKey=val2¶meter2=%D1%82%D0%B5%D1%81%D1%82%22" + expected.URL.Path = "/привет" + expected.Form = map[string][]string{ + "myKey": {"val1", "val2"}, + "parameter2": {"тест\""}, + } + asrt.Equal(expected, dump) +} + +func TestALBResponseHeaders(t *testing.T) { + asrt := assert.New(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + header := w.Header() + header.Add("X-Foo", "1") + header.Add("X-Bar", "2") + header.Add("X-Bar", "3") + header.Add("Set-Cookie", "cookie1") + header.Add("Set-Cookie", "cookie2") + w.WriteHeader(404) + io.WriteString(w, "FOO") + } + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(handler), + opts: &Options{}, + } + responseBytes, err := lh.Invoke(context.Background(), []byte(albTestEvent)) + asrt.NoError(err) + + var r events.ALBTargetGroupResponse + err = json.Unmarshal(responseBytes, &r) + asrt.NoError(err) + asrt.Equal(404, r.StatusCode) + asrt.Equal("FOO", r.Body) + expectedHeaders := map[string][]string{ + "X-Foo": {"1"}, + "X-Bar": {"2", "3"}, + "Set-Cookie": {"cookie1", "cookie2"}, + } + asrt.Equal(expectedHeaders, r.MultiValueHeaders) +} + +func TestALBBase64BodyResponseAll(t *testing.T) { + testBodyResponseAll(t, albTestEvent) +} + +func TestALBBase64BodyResponseNoMatch(t *testing.T) { + testBase64BodyResponseNoMatch(t, albTestEvent) +} + +func TestALBBase64BodyResponseMatch(t *testing.T) { + testBase64BodyResponseMatch(t, albTestEvent) +} diff --git a/apigatewayv2_test.go b/apigatewayv2_test.go deleted file mode 100644 index 5ea4dd1..0000000 --- a/apigatewayv2_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package algnhsa - -import ( - "context" - "encoding/json" - "io" - "net/http" - "testing" - - "github.com/aws/aws-lambda-go/events" -) - -var testV2Request = `{ - "version": "2.0", - "routeKey": "$default", - "rawPath": "/my/path", - "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", - "cookies": [ - "cookie1", - "cookie2" - ], - "headers": { - "header1": "value1", - "header2": "value1,value2" - }, - "queryStringParameters": { - "parameter1": "value1,value2", - "parameter2": "value" - }, - "requestContext": { - "accountId": "123456789012", - "apiId": "api-id", - "authentication": { - "clientCert": { - "clientCertPem": "CERT_CONTENT", - "subjectDN": "www.example.com", - "issuerDN": "Example issuer", - "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", - "validity": { - "notBefore": "May 28 12:30:02 2019 GMT", - "notAfter": "Aug 5 09:36:04 2021 GMT" - } - } - }, - "authorizer": { - "jwt": { - "claims": { - "claim1": "value1", - "claim2": "value2" - }, - "scopes": [ - "scope1", - "scope2" - ] - } - }, - "domainName": "id.execute-api.us-east-1.amazonaws.com", - "domainPrefix": "id", - "http": { - "method": "POST", - "path": "/my/path", - "protocol": "HTTP/1.1", - "sourceIp": "IP", - "userAgent": "agent" - }, - "requestId": "id", - "routeKey": "$default", - "stage": "$default", - "time": "12/Mar/2020:19:03:58 +0000", - "timeEpoch": 1583348638390 - }, - "body": "Hello from Lambda", - "pathParameters": { - "parameter1": "value1" - }, - "isBase64Encoded": false, - "stageVariables": { - "stageVariable1": "value1", - "stageVariable2": "value2" - } -} -` - -func TestAPIGatewayV2(t *testing.T) { - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.URL.Path != "/my/path" { - t.Errorf("expected path %q, got %q", "/my/path", r.URL.Path) - } - if r.Method != "POST" { - t.Errorf("expected method %q, got %q", "POST", r.Method) - } - io.WriteString(w, "test") - }) - lh := lambdaHandler{ - httpHandler: h, - opts: &Options{}, - } - responseBytes, err := lh.Invoke(context.Background(), []byte(testV2Request)) - if err != nil { - t.Errorf("failed to invoke handler: %v", err) - } - var r events.APIGatewayV2HTTPResponse - err = json.Unmarshal(responseBytes, &r) - if err != nil { - t.Errorf("failed to unmarshal response: %v", err) - } - if r.Body != "test" { - t.Errorf("unexpected body: %q", r.Body) - } -} diff --git a/apigw.go b/apigw.go deleted file mode 100644 index 8414aba..0000000 --- a/apigw.go +++ /dev/null @@ -1,43 +0,0 @@ -package algnhsa - -import ( - "context" - "encoding/json" - "errors" - "path" - - "github.com/aws/aws-lambda-go/events" -) - -var ( - errAPIGatewayUnexpectedRequest = errors.New("expected APIGatewayProxyRequest event") -) - -func newAPIGatewayRequest(ctx context.Context, payload []byte, opts *Options) (lambdaRequest, error) { - var event events.APIGatewayProxyRequest - if err := json.Unmarshal(payload, &event); err != nil { - return lambdaRequest{}, err - } - if event.RequestContext.AccountID == "" { - return lambdaRequest{}, errAPIGatewayUnexpectedRequest - } - - req := lambdaRequest{ - HTTPMethod: event.HTTPMethod, - Path: event.Path, - QueryStringParameters: event.QueryStringParameters, - MultiValueQueryStringParameters: event.MultiValueQueryStringParameters, - Headers: event.Headers, - MultiValueHeaders: event.MultiValueHeaders, - Body: event.Body, - IsBase64Encoded: event.IsBase64Encoded, - SourceIP: event.RequestContext.Identity.SourceIP, - Context: newProxyRequestContext(ctx, event), - } - - if opts.UseProxyPath { - req.Path = path.Join("/", event.PathParameters["proxy"]) - } - - return req, nil -} diff --git a/apigw_v1.go b/apigw_v1.go new file mode 100644 index 0000000..051836f --- /dev/null +++ b/apigw_v1.go @@ -0,0 +1,68 @@ +package algnhsa + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "path" + + "github.com/aws/aws-lambda-go/events" +) + +/* +AWS Documentation: + +- https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html +*/ + +var ( + errAPIGatewayV1UnexpectedRequest = errors.New("expected APIGatewayProxyRequest event") +) + +func newAPIGatewayV1Request(ctx context.Context, payload []byte, opts *Options) (lambdaRequest, error) { + var event events.APIGatewayProxyRequest + if err := json.Unmarshal(payload, &event); err != nil { + return lambdaRequest{}, err + } + if event.RequestContext.AccountID == "" { + return lambdaRequest{}, errAPIGatewayV1UnexpectedRequest + } + + req := lambdaRequest{ + HTTPMethod: event.HTTPMethod, + Path: event.Path, + QueryStringParameters: event.QueryStringParameters, + MultiValueQueryStringParameters: event.MultiValueQueryStringParameters, + Headers: event.Headers, + MultiValueHeaders: event.MultiValueHeaders, + Body: event.Body, + IsBase64Encoded: event.IsBase64Encoded, + SourceIP: event.RequestContext.Identity.SourceIP, + Context: context.WithValue(ctx, RequestTypeAPIGatewayV1, event), + requestType: RequestTypeAPIGatewayV1, + } + + if opts.UseProxyPath { + req.Path = path.Join("/", event.PathParameters["proxy"]) + } + + return req, nil +} + +func newAPIGatewayV1Response(r *http.Response) (lambdaResponse, error) { + resp := lambdaResponse{ + MultiValueHeaders: r.Header, + } + return resp, nil +} + +// APIGatewayV1RequestFromContext extracts the APIGatewayProxyRequest event from ctx. +func APIGatewayV1RequestFromContext(ctx context.Context) (events.APIGatewayProxyRequest, bool) { + val := ctx.Value(RequestTypeAPIGatewayV1) + if val == nil { + return events.APIGatewayProxyRequest{}, false + } + event, ok := val.(events.APIGatewayProxyRequest) + return event, ok +} diff --git a/apigw_v1_test.go b/apigw_v1_test.go new file mode 100644 index 0000000..2a1f603 --- /dev/null +++ b/apigw_v1_test.go @@ -0,0 +1,318 @@ +package algnhsa + +import ( + "context" + "encoding/json" + "errors" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "testing" + + "github.com/aws/aws-lambda-go/events" +) + +var apiGatewayV1TestEvent = `{ + "version": "1.0", + "resource": "/my/path", + "path": "/my/path", + "httpMethod": "GET", + "headers": { + "header1": "value1", + "header2": "value2" + }, + "multiValueHeaders": { + "header1": [ + "value1" + ], + "header2": [ + "value1", + "value2" + ], + "cookie": [ + "cookie1", + "cookie2" + ] + }, + "queryStringParameters": { + "parameter1": "value1", + "parameter2": "value" + }, + "multiValueQueryStringParameters": { + "parameter1": [ + "value1", + "value2" + ], + "parameter2": [ + "value" + ] + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "id", + "authorizer": { + "claims": null, + "scopes": null + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "extendedRequestId": "request-id", + "httpMethod": "GET", + "identity": { + "accessKey": null, + "accountId": null, + "caller": null, + "cognitoAuthenticationProvider": null, + "cognitoAuthenticationType": null, + "cognitoIdentityId": null, + "cognitoIdentityPoolId": null, + "principalOrgId": null, + "sourceIp": "192.0.2.1", + "user": null, + "userAgent": "user-agent", + "userArn": null, + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "path": "/my/path", + "protocol": "HTTP/1.1", + "requestId": "id=", + "requestTime": "04/Mar/2020:19:15:17 +0000", + "requestTimeEpoch": 1583349317135, + "resourceId": null, + "resourcePath": "/my/path", + "stage": "$default" + }, + "pathParameters": {"proxy": "/my/path2"}, + "stageVariables": null, + "body": "Hello from Lambda!", + "isBase64Encoded": false +} +` + +var expectedApiGatewayV1Dump = RequestDebugDump{ + Method: "GET", + URL: struct { + Path string + RawPath string + }{ + Path: "/my/path", + RawPath: "", + }, + RequestURI: "/my/path?parameter1=value1¶meter1=value2¶meter2=value", + Host: "", + RemoteAddr: "192.0.2.1", + Header: map[string][]string{ + "Header1": {"value1"}, + "Header2": {"value1", "value2"}, + "Cookie": {"cookie1", "cookie2"}, + }, + Form: map[string][]string{ + "parameter1": {"value1", "value2"}, + "parameter2": {"value"}, + }, + Body: "Hello from Lambda!", +} + +func dumpAPIGatewayV1(payload []byte, opts Options) (RequestDebugDump, error) { + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(RequestDebugDumpHandler), + opts: &opts, + } + responseBytes, err := lh.Invoke(context.Background(), payload) + if err != nil { + return RequestDebugDump{}, err + } + var r events.APIGatewayProxyResponse + if err := json.Unmarshal(responseBytes, &r); err != nil { + return RequestDebugDump{}, err + } + if r.StatusCode != 200 { + return RequestDebugDump{}, errors.New("expected status code 200") + } + var dump RequestDebugDump + if err := json.Unmarshal([]byte(r.Body), &dump); err != nil { + return RequestDebugDump{}, err + } + if dump.APIGatewayV1Request.RequestContext.HTTPMethod != "GET" { + return RequestDebugDump{}, errors.New("expected method GET") + } + dump.APIGatewayV1Request = nil + return dump, nil +} + +func TestAPIGatewayV1Base(t *testing.T) { + asrt := assert.New(t) + + dump, err := dumpAPIGatewayV1([]byte(apiGatewayV1TestEvent), Options{}) + asrt.NoError(err) + + asrt.Equal(expectedApiGatewayV1Dump, dump) +} + +func TestAPIGatewayV1ProxyPath(t *testing.T) { + asrt := assert.New(t) + + dump, err := dumpAPIGatewayV1([]byte(apiGatewayV1TestEvent), Options{UseProxyPath: true}) + asrt.NoError(err) + + expected := expectedApiGatewayV1Dump + expected.RequestURI = "/my/path2?parameter1=value1¶meter1=value2¶meter2=value" + expected.URL.Path = "/my/path2" + asrt.Equal(expected, dump) +} + +func TestAPIGatewayV1Base64BodyRequest(t *testing.T) { + asrt := assert.New(t) + + event := events.APIGatewayProxyRequest{} + asrt.NoError(json.Unmarshal([]byte(apiGatewayV1TestEvent), &event)) + event.IsBase64Encoded = true + event.Body = "SGVsbG8gZnJvbSBMYW1iZGEh" + encodedEvent, err := json.Marshal(event) + asrt.NoError(err) + + dump, err := dumpAPIGatewayV1(encodedEvent, Options{}) + asrt.NoError(err) + asrt.Equal(expectedApiGatewayV1Dump, dump) +} + +func TestAPIGatewayV1URLEncoding(t *testing.T) { + asrt := assert.New(t) + + event := events.APIGatewayProxyRequest{} + asrt.NoError(json.Unmarshal([]byte(apiGatewayV1TestEvent), &event)) + event.Path = "/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82" + event.MultiValueQueryStringParameters["parameter2"] = []string{"тест\""} + encodedEvent, err := json.Marshal(event) + asrt.NoError(err) + + dump, err := dumpAPIGatewayV1(encodedEvent, Options{}) + asrt.NoError(err) + expected := expectedApiGatewayV1Dump + expected.RequestURI = "/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82?parameter1=value1¶meter1=value2¶meter2=%D1%82%D0%B5%D1%81%D1%82%22" + expected.URL.Path = "/привет" + expected.Form = map[string][]string{ + "parameter1": {"value1", "value2"}, + "parameter2": {"тест\""}, + } + asrt.Equal(expected, dump) +} + +func TestAPIGatewayV1ResponseHeaders(t *testing.T) { + asrt := assert.New(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + header := w.Header() + header.Add("X-Foo", "1") + header.Add("X-Bar", "2") + header.Add("X-Bar", "3") + header.Add("Set-Cookie", "cookie1") + header.Add("Set-Cookie", "cookie2") + w.WriteHeader(404) + io.WriteString(w, "FOO") + } + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(handler), + opts: &Options{}, + } + responseBytes, err := lh.Invoke(context.Background(), []byte(apiGatewayV1TestEvent)) + asrt.NoError(err) + + var r events.APIGatewayProxyResponse + err = json.Unmarshal(responseBytes, &r) + asrt.NoError(err) + asrt.Equal(404, r.StatusCode) + asrt.Equal("FOO", r.Body) + expectedHeaders := map[string][]string{ + "X-Foo": {"1"}, + "X-Bar": {"2", "3"}, + "Set-Cookie": {"cookie1", "cookie2"}, + } + asrt.Equal(expectedHeaders, r.MultiValueHeaders) +} + +func testBodyResponseAll(t *testing.T, event string) { + t.Helper() + asrt := assert.New(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Hello from Lambda!") + } + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(handler), + opts: &Options{BinaryContentTypes: []string{"*/*"}}, + } + lh.opts.setBinaryContentTypeMap() + responseBytes, err := lh.Invoke(context.Background(), []byte(event)) + asrt.NoError(err) + + var r lambdaResponse + err = json.Unmarshal(responseBytes, &r) + asrt.NoError(err) + asrt.Equal(200, r.StatusCode) + asrt.Equal("SGVsbG8gZnJvbSBMYW1iZGEh", r.Body) +} + +func testBase64BodyResponseNoMatch(t *testing.T, event string) { + asrt := assert.New(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "Hello from Lambda!") + } + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(handler), + opts: &Options{BinaryContentTypes: []string{"image/png"}}, + } + lh.opts.setBinaryContentTypeMap() + responseBytes, err := lh.Invoke(context.Background(), []byte(event)) + asrt.NoError(err) + + var r events.APIGatewayProxyResponse + err = json.Unmarshal(responseBytes, &r) + asrt.NoError(err) + asrt.Equal(200, r.StatusCode) + asrt.Equal("Hello from Lambda!", r.Body) +} + +func testBase64BodyResponseMatch(t *testing.T, event string) { + asrt := assert.New(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "image/png") + io.WriteString(w, "Hello from Lambda!") + } + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(handler), + opts: &Options{BinaryContentTypes: []string{"image/png"}}, + } + lh.opts.setBinaryContentTypeMap() + responseBytes, err := lh.Invoke(context.Background(), []byte(event)) + asrt.NoError(err) + + var r events.APIGatewayProxyResponse + err = json.Unmarshal(responseBytes, &r) + asrt.NoError(err) + asrt.Equal(200, r.StatusCode) + asrt.Equal("SGVsbG8gZnJvbSBMYW1iZGEh", r.Body) +} + +func TestAPIGatewayV1Base64BodyResponseAll(t *testing.T) { + testBodyResponseAll(t, apiGatewayV1TestEvent) +} + +func TestAPIGatewayV1Base64BodyResponseNoMatch(t *testing.T) { + testBase64BodyResponseNoMatch(t, apiGatewayV1TestEvent) +} + +func TestAPIGatewayV1Base64BodyResponseMatch(t *testing.T) { + testBase64BodyResponseMatch(t, apiGatewayV1TestEvent) +} diff --git a/apigw_v2.go b/apigw_v2.go new file mode 100644 index 0000000..2da2b6c --- /dev/null +++ b/apigw_v2.go @@ -0,0 +1,90 @@ +package algnhsa + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "path" + "strings" + + "github.com/aws/aws-lambda-go/events" +) + +/* +AWS Documentation: + +- https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html +- https://docs.aws.amazon.com/lambda/latest/dg/lambda-urls.html +*/ + +var ( + errAPIGatewayV2UnexpectedRequest = errors.New("expected APIGatewayV2HTTPRequest event") +) + +func newAPIGatewayV2Request(ctx context.Context, payload []byte, opts *Options) (lambdaRequest, error) { + var event events.APIGatewayV2HTTPRequest + if err := json.Unmarshal(payload, &event); err != nil { + return lambdaRequest{}, err + } + if event.Version != "2.0" { + return lambdaRequest{}, errAPIGatewayV2UnexpectedRequest + } + + req := lambdaRequest{ + HTTPMethod: event.RequestContext.HTTP.Method, + Path: event.RawPath, + RawQueryString: event.RawQueryString, + Headers: event.Headers, + Body: event.Body, + IsBase64Encoded: event.IsBase64Encoded, + SourceIP: event.RequestContext.HTTP.SourceIP, + Context: context.WithValue(ctx, RequestTypeAPIGatewayV2, event), + requestType: RequestTypeAPIGatewayV2, + } + + // APIGatewayV2 doesn't support multi-value headers. + // For cookies there is a workaround - Cookie headers are assigned to the event Cookies slice. + // All other multi-value headers are joined into a single value with a comma. + // It would be unsafe to split such values on a comma - it's impossible to distinguish a multi-value header + // joined with a comma and a single-value header that contains a comma. + if len(event.Cookies) > 0 { + if req.MultiValueHeaders == nil { + req.MultiValueHeaders = make(map[string][]string) + } + req.MultiValueHeaders["Cookie"] = event.Cookies + } + + if opts.UseProxyPath { + req.Path = path.Join("/", event.PathParameters["proxy"]) + } + + return req, nil +} + +func newAPIGatewayV2Response(r *http.Response) (lambdaResponse, error) { + resp := lambdaResponse{ + Headers: make(map[string]string, len(r.Header)), + } + // APIGatewayV2 doesn't support multi-value headers. + for key, values := range r.Header { + // For cookies there is a workaround - Set-Cookie headers are assigned to the response Cookies slice. + if key == canonicalSetCookieHeaderKey { + resp.Cookies = values + continue + } + // All other multi-value headers are joined into a single value with a comma. + resp.Headers[key] = strings.Join(values, ",") + } + return resp, nil +} + +// APIGatewayV2RequestFromContext extracts the APIGatewayV2HTTPRequest event from ctx. +func APIGatewayV2RequestFromContext(ctx context.Context) (events.APIGatewayV2HTTPRequest, bool) { + val := ctx.Value(RequestTypeAPIGatewayV2) + if val == nil { + return events.APIGatewayV2HTTPRequest{}, false + } + event, ok := val.(events.APIGatewayV2HTTPRequest) + return event, ok +} diff --git a/apigw_v2_test.go b/apigw_v2_test.go new file mode 100644 index 0000000..fcf7066 --- /dev/null +++ b/apigw_v2_test.go @@ -0,0 +1,242 @@ +package algnhsa + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "testing" + + "github.com/aws/aws-lambda-go/events" +) + +var apiGatewayV2TestEvent = `{ + "version": "2.0", + "routeKey": "$default", + "rawPath": "/my/path", + "rawQueryString": "parameter1=value1¶meter1=value2¶meter2=value", + "cookies": [ + "cookie1", + "cookie2" + ], + "headers": { + "header1": "value1", + "header2": "value1,value2" + }, + "queryStringParameters": { + "parameter1": "value1,value2", + "parameter2": "value" + }, + "requestContext": { + "accountId": "123456789012", + "apiId": "api-id", + "authentication": { + "clientCert": { + "clientCertPem": "CERT_CONTENT", + "subjectDN": "www.example.com", + "issuerDN": "Example issuer", + "serialNumber": "a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1:a1", + "validity": { + "notBefore": "May 28 12:30:02 2019 GMT", + "notAfter": "Aug 5 09:36:04 2021 GMT" + } + } + }, + "authorizer": { + "jwt": { + "claims": { + "claim1": "value1", + "claim2": "value2" + }, + "scopes": [ + "scope1", + "scope2" + ] + } + }, + "domainName": "id.execute-api.us-east-1.amazonaws.com", + "domainPrefix": "id", + "http": { + "method": "POST", + "path": "/my/path", + "protocol": "HTTP/1.1", + "sourceIp": "IP", + "userAgent": "agent" + }, + "requestId": "id", + "routeKey": "$default", + "stage": "$default", + "time": "12/Mar/2020:19:03:58 +0000", + "timeEpoch": 1583348638390 + }, + "body": "Hello from Lambda", + "pathParameters": { + "parameter1": "value1", + "proxy": "/my/path2" + }, + "isBase64Encoded": false, + "stageVariables": { + "stageVariable1": "value1", + "stageVariable2": "value2" + } +} +` + +var expectedApiGatewayV2Dump = RequestDebugDump{ + Method: "POST", + URL: struct { + Path string + RawPath string + }{ + Path: "/my/path", + RawPath: "", + }, + RequestURI: "/my/path?parameter1=value1¶meter1=value2¶meter2=value", + Host: "", + RemoteAddr: "IP", + Header: map[string][]string{ + "Header1": {"value1"}, + "Header2": {"value1,value2"}, + "Cookie": {"cookie1", "cookie2"}, + }, + Form: map[string][]string{ + "parameter1": {"value1", "value2"}, + "parameter2": {"value"}, + }, + Body: "Hello from Lambda", +} + +func dumpAPIGatewayV2(payload []byte, opts Options) (RequestDebugDump, error) { + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(RequestDebugDumpHandler), + opts: &opts, + } + responseBytes, err := lh.Invoke(context.Background(), payload) + if err != nil { + return RequestDebugDump{}, err + } + var r events.APIGatewayV2HTTPResponse + if err := json.Unmarshal(responseBytes, &r); err != nil { + return RequestDebugDump{}, err + } + if r.StatusCode != 200 { + return RequestDebugDump{}, errors.New("expected status code 200") + } + var dump RequestDebugDump + if err := json.Unmarshal([]byte(r.Body), &dump); err != nil { + return RequestDebugDump{}, err + } + if dump.APIGatewayV2Request.RequestContext.HTTP.Method != "POST" { + fmt.Printf("%+v\n", dump) + return RequestDebugDump{}, errors.New("expected method POST") + } + dump.APIGatewayV2Request = nil + return dump, nil +} + +func TestAPIGatewayV2Base(t *testing.T) { + asrt := assert.New(t) + + dump, err := dumpAPIGatewayV2([]byte(apiGatewayV2TestEvent), Options{}) + asrt.NoError(err) + + asrt.Equal(expectedApiGatewayV2Dump, dump) +} + +func TestAPIGatewayV2ProxyPath(t *testing.T) { + asrt := assert.New(t) + + dump, err := dumpAPIGatewayV2([]byte(apiGatewayV2TestEvent), Options{UseProxyPath: true}) + asrt.NoError(err) + + expected := expectedApiGatewayV2Dump + expected.RequestURI = "/my/path2?parameter1=value1¶meter1=value2¶meter2=value" + expected.URL.Path = "/my/path2" + asrt.Equal(expected, dump) +} + +func TestAPIGatewayV2Base64BodyRequest(t *testing.T) { + asrt := assert.New(t) + + event := events.APIGatewayV2HTTPRequest{} + asrt.NoError(json.Unmarshal([]byte(apiGatewayV2TestEvent), &event)) + event.IsBase64Encoded = true + event.Body = "SGVsbG8gZnJvbSBMYW1iZGE=" + encodedEvent, err := json.Marshal(event) + asrt.NoError(err) + + dump, err := dumpAPIGatewayV2(encodedEvent, Options{}) + asrt.NoError(err) + asrt.Equal(expectedApiGatewayV2Dump, dump) +} + +func TestAPIGatewayV2URLEncoding(t *testing.T) { + asrt := assert.New(t) + + event := events.APIGatewayV2HTTPRequest{} + asrt.NoError(json.Unmarshal([]byte(apiGatewayV2TestEvent), &event)) + event.RawPath = "/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82" + event.RawQueryString = "parameter1=value1¶meter1=value2¶meter2=%D1%82%D0%B5%D1%81%D1%82%22" + event.QueryStringParameters["parameter2"] = "тест\"" + encodedEvent, err := json.Marshal(event) + asrt.NoError(err) + + dump, err := dumpAPIGatewayV2(encodedEvent, Options{}) + asrt.NoError(err) + expected := expectedApiGatewayV2Dump + expected.RequestURI = "/%D0%BF%D1%80%D0%B8%D0%B2%D0%B5%D1%82?parameter1=value1¶meter1=value2¶meter2=%D1%82%D0%B5%D1%81%D1%82%22" + expected.URL.Path = "/привет" + expected.Form = map[string][]string{ + "parameter1": {"value1", "value2"}, + "parameter2": {"тест\""}, + } + asrt.Equal(expected, dump) +} + +func TestAPIGatewayV2ResponseHeaders(t *testing.T) { + asrt := assert.New(t) + + handler := func(w http.ResponseWriter, r *http.Request) { + header := w.Header() + header.Add("X-Foo", "1") + header.Add("X-Bar", "2") + header.Add("X-Bar", "3") + header.Add("Set-Cookie", "cookie1") + header.Add("Set-Cookie", "cookie2") + w.WriteHeader(404) + io.WriteString(w, "FOO") + } + lh := lambdaHandler{ + httpHandler: http.HandlerFunc(handler), + opts: &Options{}, + } + responseBytes, err := lh.Invoke(context.Background(), []byte(apiGatewayV2TestEvent)) + asrt.NoError(err) + + var r events.APIGatewayV2HTTPResponse + err = json.Unmarshal(responseBytes, &r) + asrt.NoError(err) + asrt.Equal(404, r.StatusCode) + asrt.Equal("FOO", r.Body) + expectedHeaders := map[string]string{ + "X-Foo": "1", + "X-Bar": "2,3", + } + asrt.Equal(expectedHeaders, r.Headers) + asrt.Equal([]string{"cookie1", "cookie2"}, r.Cookies) +} + +func TestAPIGatewayV2Base64BodyResponseAll(t *testing.T) { + testBodyResponseAll(t, apiGatewayV2TestEvent) +} + +func TestAPIGatewayV2Base64BodyResponseNoMatch(t *testing.T) { + testBase64BodyResponseNoMatch(t, apiGatewayV2TestEvent) +} + +func TestAPIGatewayV2Base64BodyResponseMatch(t *testing.T) { + testBase64BodyResponseMatch(t, apiGatewayV2TestEvent) +} diff --git a/context.go b/context.go deleted file mode 100644 index 4faacd7..0000000 --- a/context.go +++ /dev/null @@ -1,57 +0,0 @@ -package algnhsa - -import ( - "context" - - "github.com/aws/aws-lambda-go/events" -) - -type contextKey int - -const ( - proxyRequestContextKey contextKey = iota - apiGatewayV2HTTPRequestContextKey - albRequestContextKey -) - -func newProxyRequestContext(ctx context.Context, event events.APIGatewayProxyRequest) context.Context { - return context.WithValue(ctx, proxyRequestContextKey, event) -} - -// ProxyRequestFromContext extracts the APIGatewayProxyRequest event from ctx. -func ProxyRequestFromContext(ctx context.Context) (events.APIGatewayProxyRequest, bool) { - val := ctx.Value(proxyRequestContextKey) - if val == nil { - return events.APIGatewayProxyRequest{}, false - } - event, ok := val.(events.APIGatewayProxyRequest) - return event, ok -} - -func newAPIGatewayV2HTTPRequestContext(ctx context.Context, event events.APIGatewayV2HTTPRequest) context.Context { - return context.WithValue(ctx, apiGatewayV2HTTPRequestContextKey, event) -} - -// APIGatewayV2HTTPRequestFromContext extracts the APIGatewayV2HTTPRequest event from ctx. -func APIGatewayV2HTTPRequestFromContext(ctx context.Context) (events.APIGatewayV2HTTPRequest, bool) { - val := ctx.Value(apiGatewayV2HTTPRequestContextKey) - if val == nil { - return events.APIGatewayV2HTTPRequest{}, false - } - event, ok := val.(events.APIGatewayV2HTTPRequest) - return event, ok -} - -func newTargetGroupRequestContext(ctx context.Context, event events.ALBTargetGroupRequest) context.Context { - return context.WithValue(ctx, albRequestContextKey, event) -} - -// TargetGroupRequestFromContext extracts the ALBTargetGroupRequest event from ctx. -func TargetGroupRequestFromContext(ctx context.Context) (events.ALBTargetGroupRequest, bool) { - val := ctx.Value(albRequestContextKey) - if val == nil { - return events.ALBTargetGroupRequest{}, false - } - event, ok := val.(events.ALBTargetGroupRequest) - return event, ok -} diff --git a/debug.go b/debug.go new file mode 100644 index 0000000..e28eb2a --- /dev/null +++ b/debug.go @@ -0,0 +1,103 @@ +package algnhsa + +import ( + "encoding/json" + "github.com/aws/aws-lambda-go/events" + "io" + "mime" + "net/http" +) + +const maxDumpFormParseMem = 32 << 20 // 32MB + +// RequestDebugDump is a dump of the HTTP request including the original Lambda event. +type RequestDebugDump struct { + Method string + URL struct { + Path string + RawPath string + } + RequestURI string + Host string + RemoteAddr string + Header map[string][]string + Form map[string][]string + Body string + APIGatewayV1Request *events.APIGatewayProxyRequest `json:",omitempty"` + APIGatewayV2Request *events.APIGatewayV2HTTPRequest `json:",omitempty"` + ALBRequest *events.ALBTargetGroupRequest `json:",omitempty"` +} + +func parseMediaType(r *http.Request) (string, error) { + ct := r.Header.Get("Content-Type") + if ct == "" { + return "", nil + } + mt, _, err := mime.ParseMediaType(ct) + return mt, err +} + +// NewRequestDebugDump creates a new RequestDebugDump from an HTTP request. +func NewRequestDebugDump(r *http.Request) (*RequestDebugDump, error) { + mt, err := parseMediaType(r) + if err != nil { + return nil, err + } + if mt == "multipart/form-data" { + if err := r.ParseMultipartForm(maxDumpFormParseMem); err != nil { + return nil, err + } + } else { + if err := r.ParseForm(); err != nil { + return nil, err + } + } + + body, err := io.ReadAll(r.Body) + if err != nil { + return nil, err + } + + dump := &RequestDebugDump{ + Method: r.Method, + URL: struct { + Path string + RawPath string + }{Path: r.URL.Path, RawPath: r.URL.RawPath}, + RequestURI: r.RequestURI, + Host: r.Host, + RemoteAddr: r.RemoteAddr, + Header: r.Header, + Form: r.Form, + Body: string(body), + } + + if event, ok := APIGatewayV1RequestFromContext(r.Context()); ok { + dump.APIGatewayV1Request = &event + } + if event, ok := APIGatewayV2RequestFromContext(r.Context()); ok { + dump.APIGatewayV2Request = &event + } + if event, ok := ALBRequestFromContext(r.Context()); ok { + dump.ALBRequest = &event + } + + return dump, nil +} + +// RequestDebugDumpHandler is an HTTP handler that returns JSON encoded RequestDebugDump. +func RequestDebugDumpHandler(w http.ResponseWriter, r *http.Request) { + dump, err := NewRequestDebugDump(r) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, err.Error()) + return + } + enc := json.NewEncoder(w) + if err := enc.Encode(dump); err != nil { + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, err.Error()) + return + } + return +} diff --git a/example_test.go b/example_test.go index 5e2cd96..691da3c 100644 --- a/example_test.go +++ b/example_test.go @@ -8,10 +8,6 @@ import ( "github.com/akrylysov/algnhsa" ) -func indexHandler(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("index")) -} - func addHandler(w http.ResponseWriter, r *http.Request) { f, _ := strconv.Atoi(r.FormValue("first")) s, _ := strconv.Atoi(r.FormValue("second")) @@ -20,14 +16,13 @@ func addHandler(w http.ResponseWriter, r *http.Request) { } func contextHandler(w http.ResponseWriter, r *http.Request) { - proxyReq, ok := algnhsa.ProxyRequestFromContext(r.Context()) + lambdaEvent, ok := algnhsa.APIGatewayV2RequestFromContext(r.Context()) if ok { - fmt.Fprint(w, proxyReq.RequestContext.AccountID) + fmt.Fprint(w, lambdaEvent.RequestContext.AccountID) } } -func Example() { - http.HandleFunc("/", indexHandler) +func main() { http.HandleFunc("/add", addHandler) http.HandleFunc("/context", contextHandler) algnhsa.ListenAndServe(http.DefaultServeMux, nil) diff --git a/go.mod b/go.mod index 1c2e90d..c3d0b7a 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,14 @@ module github.com/akrylysov/algnhsa -go 1.12 +go 1.18 require ( - github.com/aws/aws-lambda-go v1.27.0 - github.com/stretchr/testify v1.6.1 + github.com/aws/aws-lambda-go v1.37.0 + github.com/stretchr/testify v1.7.2 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index b4e7ed6..e7403f1 100644 --- a/go.sum +++ b/go.sum @@ -1,22 +1,14 @@ -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= -github.com/aws/aws-lambda-go v1.27.0 h1:aLzrJwdyHoF1A18YeVdJjX8Ixkd+bpogdxVInvHcWjM= -github.com/aws/aws-lambda-go v1.27.0/go.mod h1:jJmlefzPfGnckuHdXX7/80O3BvUUi12XOkbv4w9SGLU= -github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= -github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= +github.com/aws/aws-lambda-go v1.37.0 h1:WXkQ/xhIcXZZ2P5ZBEw+bbAKeCEcb5NtiYpSwVVzIXg= +github.com/aws/aws-lambda-go v1.37.0/go.mod h1:jwFe2KmMsHmffA1X2R09hH6lFzJQxzI8qK17ewzbQMM= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/russross/blackfriday/v2 v2.0.1/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/urfave/cli/v2 v2.2.0/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776 h1:tQIYjPdBoyREyB9XMu+nnTclpTYkz2zFM+lzLJFO4gQ= -gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/httpapi.go b/httpapi.go deleted file mode 100644 index b7bc818..0000000 --- a/httpapi.go +++ /dev/null @@ -1,41 +0,0 @@ -package algnhsa - -import ( - "context" - "encoding/json" - "errors" - "path" - - "github.com/aws/aws-lambda-go/events" -) - -var ( - errAPIGatewayV2UnexpectedRequest = errors.New("expected APIGatewayV2HTTPRequest event") -) - -func newAPIGatewayV2HTTPRequest(ctx context.Context, payload []byte, opts *Options) (lambdaRequest, error) { - var event events.APIGatewayV2HTTPRequest - if err := json.Unmarshal(payload, &event); err != nil { - return lambdaRequest{}, err - } - if event.Version != "2.0" { - return lambdaRequest{}, errAPIGatewayV2UnexpectedRequest - } - - req := lambdaRequest{ - HTTPMethod: event.RequestContext.HTTP.Method, - Path: event.RequestContext.HTTP.Path, - QueryStringParameters: event.QueryStringParameters, - Headers: event.Headers, - Body: event.Body, - IsBase64Encoded: event.IsBase64Encoded, - SourceIP: event.RequestContext.HTTP.SourceIP, - Context: newAPIGatewayV2HTTPRequestContext(ctx, event), - } - - if opts.UseProxyPath { - req.Path = path.Join("/", event.PathParameters["proxy"]) - } - - return req, nil -} diff --git a/options.go b/options.go index 355be05..39a18e6 100644 --- a/options.go +++ b/options.go @@ -4,7 +4,7 @@ type RequestType int const ( RequestTypeAuto RequestType = iota - RequestTypeAPIGateway + RequestTypeAPIGatewayV1 RequestTypeAPIGatewayV2 RequestTypeALB ) @@ -12,7 +12,7 @@ const ( // Options holds the optional parameters. type Options struct { // RequestType sets the expected request type. - // By default algnhsa deduces the request type from the lambda function payload. + // By default, algnhsa deduces the request type from the lambda function payload. RequestType RequestType // BinaryContentTypes sets content types that should be treated as binary types. @@ -20,9 +20,12 @@ type Options struct { BinaryContentTypes []string binaryContentTypeMap map[string]bool - // Use API Gateway PathParameters["proxy"] when constructing the request url. + // Use API Gateway V1 PathParameters["proxy"] when constructing the request url. // Strips the base path mapping when using a custom domain with API Gateway. UseProxyPath bool + + // DebugLog enables printing request and response objects to stdout. + DebugLog bool } func (opts *Options) setBinaryContentTypeMap() { diff --git a/request.go b/request.go index 8d62ce2..89e94dc 100644 --- a/request.go +++ b/request.go @@ -10,32 +10,36 @@ import ( "strings" ) +var errUnsupportedPayloadFormat = errors.New("unsupported payload format; supported formats: APIGatewayV2HTTPRequest, APIGatewayProxyRequest, ALBTargetGroupRequest") + type lambdaRequest struct { - HTTPMethod string `json:"httpMethod"` - Path string `json:"path"` - QueryStringParameters map[string]string `json:"queryStringParameters,omitempty"` - MultiValueQueryStringParameters map[string][]string `json:"multiValueQueryStringParameters,omitempty"` - Headers map[string]string `json:"headers,omitempty"` - MultiValueHeaders map[string][]string `json:"multiValueHeaders,omitempty"` - IsBase64Encoded bool `json:"isBase64Encoded"` - Body string `json:"body"` + HTTPMethod string + Path string + QueryStringParameters map[string]string + MultiValueQueryStringParameters map[string][]string + RawQueryString string + Headers map[string]string + MultiValueHeaders map[string][]string + IsBase64Encoded bool + Body string SourceIP string Context context.Context + requestType RequestType } func newLambdaRequest(ctx context.Context, payload []byte, opts *Options) (lambdaRequest, error) { switch opts.RequestType { - case RequestTypeAPIGateway: - return newAPIGatewayRequest(ctx, payload, opts) + case RequestTypeAPIGatewayV1: + return newAPIGatewayV1Request(ctx, payload, opts) case RequestTypeAPIGatewayV2: - return newAPIGatewayV2HTTPRequest(ctx, payload, opts) + return newAPIGatewayV2Request(ctx, payload, opts) case RequestTypeALB: return newALBRequest(ctx, payload, opts) } // The request type wasn't specified. - // Try to decode the payload as APIGatewayV2HTTPRequest, fails back to APIGatewayProxyRequest, then ALBTargetGroupRequest. - req, err := newAPIGatewayV2HTTPRequest(ctx, payload, opts) + // Try to decode the payload as APIGatewayV2HTTPRequest, fall back to APIGatewayProxyRequest, then ALBTargetGroupRequest. + req, err := newAPIGatewayV2Request(ctx, payload, opts) if err != nil && err != errAPIGatewayV2UnexpectedRequest { return lambdaRequest{}, err } @@ -43,8 +47,8 @@ func newLambdaRequest(ctx context.Context, payload []byte, opts *Options) (lambd return req, nil } - req, err = newAPIGatewayRequest(ctx, payload, opts) - if err != nil && err != errAPIGatewayUnexpectedRequest { + req, err = newAPIGatewayV1Request(ctx, payload, opts) + if err != nil && err != errAPIGatewayV1UnexpectedRequest { return lambdaRequest{}, err } if err == nil { @@ -59,24 +63,27 @@ func newLambdaRequest(ctx context.Context, payload []byte, opts *Options) (lambd return req, nil } - return lambdaRequest{}, errors.New("neither APIGatewayProxyRequest nor ALBTargetGroupRequest received") + return lambdaRequest{}, errUnsupportedPayloadFormat } func newHTTPRequest(event lambdaRequest) (*http.Request, error) { // Build request URL. - params := url.Values{} - for k, v := range event.QueryStringParameters { - params.Set(k, v) - } - for k, vals := range event.MultiValueQueryStringParameters { - params[k] = vals + rawQuery := event.RawQueryString + if len(rawQuery) == 0 { + params := url.Values{} + for k, v := range event.QueryStringParameters { + params.Set(k, v) + } + for k, vals := range event.MultiValueQueryStringParameters { + params[k] = vals + } + rawQuery = params.Encode() } - // Set headers. // https://docs.aws.amazon.com/apigateway/latest/developerguide/set-up-lambda-proxy-integrations.html - // If you specify values for both headers and multiValueHeaders, API Gateway merges them into a single list. + // If you specify values for both headers and multiValueHeaders, API Gateway V1 merges them into a single list. // If the same key-value pair is specified in both, only the values from multiValueHeaders will appear - // the merged list. + // in the merged list. headers := make(http.Header) for k, v := range event.Headers { headers.Set(k, v) @@ -85,21 +92,14 @@ func newHTTPRequest(event lambdaRequest) (*http.Request, error) { headers[http.CanonicalHeaderKey(k)] = vals } - u := url.URL{ - Host: headers.Get("host"), - RawPath: event.Path, - RawQuery: params.Encode(), - } - - // Unescape request path - p, err := url.PathUnescape(u.RawPath) + unescapedPath, err := url.PathUnescape(event.Path) if err != nil { return nil, err } - u.Path = p - - if u.Path == u.RawPath { - u.RawPath = "" + u := url.URL{ + Host: headers.Get("Host"), + Path: unescapedPath, + RawQuery: rawQuery, } // Handle base64 encoded body. @@ -109,7 +109,7 @@ func newHTTPRequest(event lambdaRequest) (*http.Request, error) { } // Create a new request. - r, err := http.NewRequest(event.HTTPMethod, u.String(), body) + r, err := http.NewRequestWithContext(event.Context, event.HTTPMethod, u.String(), body) if err != nil { return nil, err } @@ -122,5 +122,5 @@ func newHTTPRequest(event lambdaRequest) (*http.Request, error) { r.Header = headers - return r.WithContext(event.Context), nil + return r, nil } diff --git a/response.go b/response.go index c08c3c3..e7b8500 100644 --- a/response.go +++ b/response.go @@ -2,36 +2,52 @@ package algnhsa import ( "encoding/base64" + "net/http" "net/http/httptest" ) const acceptAllContentType = "*/*" +var canonicalSetCookieHeaderKey = http.CanonicalHeaderKey("Set-Cookie") + +// lambdaResponse is a combined lambda response. +// It contains common fields from APIGatewayProxyResponse, APIGatewayV2HTTPResponse and ALBTargetGroupResponse. type lambdaResponse struct { StatusCode int `json:"statusCode"` - Headers map[string]string `json:"headers"` - MultiValueHeaders map[string][]string `json:"multiValueHeaders"` + Headers map[string]string `json:"headers,omitempty"` + MultiValueHeaders map[string][]string `json:"multiValueHeaders,omitempty"` + Cookies []string `json:"cookies,omitempty"` Body string `json:"body"` IsBase64Encoded bool `json:"isBase64Encoded,omitempty"` } -func newLambdaResponse(w *httptest.ResponseRecorder, binaryContentTypes map[string]bool) (lambdaResponse, error) { - event := lambdaResponse{} - - // Set status code. - event.StatusCode = w.Code +func newLambdaResponse(w *httptest.ResponseRecorder, binaryContentTypes map[string]bool, requestType RequestType) (lambdaResponse, error) { + result := w.Result() + + var resp lambdaResponse + var err error + switch requestType { + case RequestTypeAPIGatewayV1: + resp, err = newAPIGatewayV1Response(result) + case RequestTypeALB: + resp, err = newALBResponse(result) + case RequestTypeAPIGatewayV2: + resp, err = newAPIGatewayV2Response(result) + } + if err != nil { + return resp, err + } - // Set headers. - event.MultiValueHeaders = w.Result().Header + resp.StatusCode = result.StatusCode // Set body. - contentType := w.Header().Get("Content-Type") + contentType := result.Header.Get("Content-Type") if binaryContentTypes[acceptAllContentType] || binaryContentTypes[contentType] { - event.Body = base64.StdEncoding.EncodeToString(w.Body.Bytes()) - event.IsBase64Encoded = true + resp.Body = base64.StdEncoding.EncodeToString(w.Body.Bytes()) + resp.IsBase64Encoded = true } else { - event.Body = w.Body.String() + resp.Body = w.Body.String() } - return event, nil + return resp, nil }