diff --git a/graphql/handler/testserver/testserver.go b/graphql/handler/testserver/testserver.go index c648c7ed8f9..94d8a6fe5a3 100644 --- a/graphql/handler/testserver/testserver.go +++ b/graphql/handler/testserver/testserver.go @@ -44,7 +44,7 @@ func New() *TestServer { switch opCtx.Operation.Operation { case ast.Query: ran := false - // If the query contains @defer, we will mimick a deferred response. + // If the query contains @defer, we will mimic a deferred response. if strings.Contains(opCtx.RawQuery, "@defer") { initialResponse := true return func(context context.Context) *graphql.Response { diff --git a/graphql/handler/transport/http_multipart_mixed.go b/graphql/handler/transport/http_multipart_mixed.go index bd0f904cf0d..6447b286043 100644 --- a/graphql/handler/transport/http_multipart_mixed.go +++ b/graphql/handler/transport/http_multipart_mixed.go @@ -256,6 +256,13 @@ func (a *multipartResponseAggregator) flush(w http.ResponseWriter) { writeJson(w, a.initialResponse) hasNext = a.initialResponse.HasNext != nil && *a.initialResponse.HasNext + + // Handle when initial is aggregated with deferred responses. + if len(a.deferResponses) > 0 { + fmt.Fprintf(w, "\r\n") + writeBoundary(w, a.boundary, false) + } + // Reset the initial response so we don't send it again a.initialResponse = nil } diff --git a/graphql/handler/transport/http_multipart_mixed_test.go b/graphql/handler/transport/http_multipart_mixed_test.go index 090a3e96bf6..249591794da 100644 --- a/graphql/handler/transport/http_multipart_mixed_test.go +++ b/graphql/handler/transport/http_multipart_mixed_test.go @@ -8,6 +8,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -102,7 +103,7 @@ func TestMultipartMixed(t *testing.T) { return bs } - t.Run("initial and incremental patches", func(t *testing.T) { + t.Run("initial and incremental patches un-aggregated", func(t *testing.T) { handler, srv := initializeWithServer() defer srv.Close() @@ -167,4 +168,69 @@ func TestMultipartMixed(t *testing.T) { wg.Wait() }) + + t.Run("initial and incremental patches aggregated", func(t *testing.T) { + handler := testserver.New() + handler.AddTransport(transport.MultipartMixed{ + Boundary: "graphql", + DeliveryTimeout: time.Hour, + }) + + srv := httptest.NewServer(handler) + defer srv.Close() + + var err error + var res *http.Response + + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + client := &http.Client{} + req := createHTTPRequest( + srv.URL, + `{"query":"query { ... @defer { name } }"}`, + ) + res, err = client.Do(req) + }() + + handler.SendNextSubscriptionMessage() + handler.SendNextSubscriptionMessage() + handler.SendCompleteSubscriptionMessage() + wg.Wait() + + require.NoError(t, err, "Request threw error -> %s", err) + defer func() { + require.NoError(t, res.Body.Close()) + }() + + assert.Equal(t, 200, res.StatusCode, "Request return wrong status -> %d", res.Status) + assert.Equal(t, "keep-alive", res.Header.Get("Connection")) + assert.Contains(t, res.Header.Get("Content-Type"), "multipart/mixed") + assert.Contains(t, res.Header.Get("Content-Type"), `boundary="graphql"`) + + br := bufio.NewReader(res.Body) + assert.Equal(t, "--graphql\r\n", readLine(br)) + assert.Equal(t, "Content-Type: application/json\r\n", readLine(br)) + assert.Equal(t, "\r\n", readLine(br)) + assert.Equal(t, + "{\"data\":{\"name\":null},\"hasNext\":true}\r\n", + readLine(br), + ) + + assert.Equal(t, "--graphql\r\n", readLine(br)) + assert.Equal(t, "Content-Type: application/json\r\n", readLine(br)) + assert.Equal(t, "\r\n", readLine(br)) + assert.Equal( + t, + "{\"incremental\":[{\"data\":{\"name\":\"test\"},\"hasNext\":false}],\"hasNext\":false}\r\n", + readLine(br), + ) + + assert.Equal(t, "--graphql--\r\n", readLine(br)) + + _, err = br.ReadByte() + assert.Equal(t, err, io.EOF) + + }) }