diff --git a/src/caduceus/caduceus_test.go b/src/caduceus/caduceus_test.go index 32c3f87e..9216eb69 100644 --- a/src/caduceus/caduceus_test.go +++ b/src/caduceus/caduceus_test.go @@ -138,6 +138,7 @@ func TestServerHandler(t *testing.T) { } req := httptest.NewRequest("POST", "localhost:8080", strings.NewReader("Test payload.")) + badReq := httptest.NewRequest("GET", "localhost:8080", strings.NewReader("Test payload.")) t.Run("TestServeHTTPHappyPath", func(t *testing.T) { req.Header.Set("Content-Type", "application/json") @@ -151,6 +152,18 @@ func TestServerHandler(t *testing.T) { fakeHealth.AssertExpectations(t) }) + t.Run("TestServeHTTPBadMethod", func(t *testing.T) { + badReq.Header.Set("Content-Type", "application/json") + + w := httptest.NewRecorder() + serverWrapper.ServeHTTP(w, badReq) + resp := w.Result() + + assert.Equal(400, resp.StatusCode) + fakeHandler.AssertExpectations(t) + fakeHealth.AssertExpectations(t) + }) + t.Run("TestServeHTTPTooManyHeaders", func(t *testing.T) { req.Header.Add("Content-Type", "too/many/headers") diff --git a/src/caduceus/http.go b/src/caduceus/http.go index 0172061f..ed3f66f7 100644 --- a/src/caduceus/http.go +++ b/src/caduceus/http.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "fmt" "github.com/Comcast/webpa-common/logging" "io/ioutil" "net/http" @@ -21,17 +22,22 @@ type ServerHandler struct { func (sh *ServerHandler) ServeHTTP(response http.ResponseWriter, request *http.Request) { defer request.Body.Close() - sh.Info("Receiving incoming post...") + sh.Info("Receiving incoming request...") timeStamps := CaduceusTimestamps{ TimeReceived: time.Now(), } + if request.Method != "POST" { + response.WriteHeader(http.StatusBadRequest) + response.Write([]byte(fmt.Sprintf("Unsupported method \"%s\"... Caduceus only supports \"POST\" method.\n", request.Method))) + return + } + myPayload, err := ioutil.ReadAll(request.Body) if err != nil { - statusMsg := "Unable to retrieve the request body: " + err.Error() + ".\n" response.WriteHeader(http.StatusBadRequest) - response.Write([]byte(statusMsg)) + response.Write([]byte(fmt.Sprintf("Unable to retrieve the request body: %s.\n", err.Error))) return }