From 8c22cf4e9669c8ebb5a71ea72ba506e5112c885d Mon Sep 17 00:00:00 2001 From: "Holaday, Sean" Date: Wed, 12 Apr 2017 16:53:01 -0700 Subject: [PATCH 1/2] Updating code so that caduceus only supports POST method requests. --- src/caduceus/http.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/caduceus/http.go b/src/caduceus/http.go index 0172061f..ed3f66f7 100644 --- a/src/caduceus/http.go +++ b/src/caduceus/http.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "github.com/Comcast/webpa-common/logging" "io/ioutil" "net/http" @@ -21,17 +22,22 @@ type ServerHandler struct { func (sh *ServerHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { defer request.Body.Close() - sh.Info("Receiving incoming post...") + sh.Info("Receiving incoming request...") timeStamps := CaduceusTimestamps{ TimeReceived: time.Now(), } + if request.Method != "POST" { + response.WriteHeader(http.StatusBadRequest) + response.Write([]byte(fmt.Sprintf("Unsupported method \"%s\"... Caduceus only supports \"POST\" method.\n", request.Method))) + return + } + myPayload, err := ioutil.ReadAll(request.Body) if err != nil { - statusMsg := "Unable to retrieve the request body: " + err.Error() + ".\n" response.WriteHeader(http.StatusBadRequest) - response.Write([]byte(statusMsg)) + response.Write([]byte(fmt.Sprintf("Unable to retrieve the request body: %s.\n", err.Error))) return } From 4a179f580f9f42a3ccdd7e736ee631ca997bb24b Mon Sep 17 00:00:00 2001 From: "Holaday, Sean" Date: Wed, 12 Apr 2017 17:02:03 -0700 Subject: [PATCH 2/2] Added in a test case for when `caduceus` is queried with a bad method (a.k.a. not POST) --- src/caduceus/caduceus_test.go | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/caduceus/caduceus_test.go b/src/caduceus/caduceus_test.go index 32c3f87e..9216eb69 100644 --- a/src/caduceus/caduceus_test.go +++ b/src/caduceus/caduceus_test.go @@ -138,6 +138,7 @@ func TestServerHandler(t *testing.T) { } req := httptest.NewRequest("POST", "localhost:8080", strings.NewReader("Test payload.")) + badReq := httptest.NewRequest("GET", "localhost:8080", strings.NewReader("Test payload.")) t.Run("TestServeHTTPHappyPath", func(t *testing.T) { req.Header.Set("Content-Type", "application/json") @@ -151,6 +152,18 @@ func TestServerHandler(t *testing.T) { fakeHealth.AssertExpectations(t) }) + t.Run("TestServeHTTPBadMethod", func(t *testing.T) { + badReq.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + serverWrapper.ServeHTTP(w, badReq) + resp := w.Result() + + assert.Equal(400, resp.StatusCode) + fakeHandler.AssertExpectations(t) + fakeHealth.AssertExpectations(t) + }) + t.Run("TestServeHTTPTooManyHeaders", func(t *testing.T) { req.Header.Add("Content-Type", "too/many/headers")