Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

multipart/mixed transport support for deferred queries #3341

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions graphql/handler/testserver/testserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"strings"
"time"

"github.com/vektah/gqlparser/v2"
Expand Down Expand Up @@ -43,6 +44,32 @@ func New() *TestServer {
switch rc.Operation.Operation {
case ast.Query:
ran := false
// If the query contains @defer, we will mimick a deferred response.
if strings.Contains(rc.RawQuery, "@defer") {
initialResponse := true
return func(context context.Context) *graphql.Response {
select {
case <-ctx.Done():
return nil
case <-next:
if initialResponse {
initialResponse = false
hasNext := true
return &graphql.Response{
Data: []byte(`{"name":null}`),
HasNext: &hasNext,
}
}
hasNext := false
return &graphql.Response{
Data: []byte(`{"name":"test"}`),
HasNext: &hasNext,
}
case <-completeSubscription:
return nil
}
}
}
return func(ctx context.Context) *graphql.Response {
if ran {
return nil
Expand All @@ -59,9 +86,10 @@ func New() *TestServer {
},
},
})
res, err := graphql.GetOperationContext(ctx).ResolverMiddleware(ctx, func(ctx context.Context) (any, error) {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil
})
res, err := graphql.GetOperationContext(ctx).
ResolverMiddleware(ctx, func(ctx context.Context) (any, error) {
return &graphql.Response{Data: []byte(`{"name":"test"}`)}, nil
})
if err != nil {
panic(err)
}
Expand Down
160 changes: 160 additions & 0 deletions graphql/handler/transport/http_multipart_mixed.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package transport

import (
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"

"github.com/vektah/gqlparser/v2/gqlerror"

"github.com/99designs/gqlgen/graphql"
)

// MultipartMixed is a transport that supports the multipart/mixed spec
type MultipartMixed struct {
Boundary string
}

var _ graphql.Transport = MultipartMixed{}

// Supports checks if the request supports the multipart/mixed spec
// Might be worth check the spec required, but Apollo Client mislabel the spec in the headers.
func (t MultipartMixed) Supports(r *http.Request) bool {
if !strings.Contains(r.Header.Get("Accept"), "multipart/mixed") {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == http.MethodPost && mediaType == "application/json"
}

// Do implements the multipart/mixed spec as a multipart/mixed response
func (t MultipartMixed) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
// Implements the multipart/mixed spec as a multipart/mixed response:
// * https://github.com/graphql/graphql-wg/blob/e4ef5f9d5997815d9de6681655c152b6b7838b4c/rfcs/DeferStream.md
// 2022/08/23 as implemented by gqlgen.
// * https://github.com/graphql/graphql-wg/blob/f22ea7748c6ebdf88fdbf770a8d9e41984ebd429/rfcs/DeferStream.md June 2023 Spec for the
// `incremental` field
// Follows the format that is used in the Apollo Client tests:
// https://github.com/apollographql/apollo-client/blob/v3.11.8/src/link/http/__tests__/responseIterator.ts#L68
// Apollo Client, despite mentioning in its requests that they require the 2022 spec, it wants the
// `incremental` field to be an array of responses, not a single response. Theoretically we could
// batch responses in the `incremental` field, if we wanted to optimize this code.
ctx := r.Context()
flusher, ok := w.(http.Flusher)
if !ok {
SendErrorf(w, http.StatusInternalServerError, "streaming unsupported")
return
}
defer flusher.Flush()

w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
// This header will be replaced below, but it's required in case we return errors.
w.Header().Set("Content-Type", "application/json")

boundary := t.Boundary
if boundary == "" {
boundary = "graphql"
}

params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}

bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
return
}

bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}

rc, opErr := exec.CreateOperationContext(ctx, params)
ctx = graphql.WithOperationContext(ctx, rc)

// Example of the response format (note the new lines are important!):
// --graphql
// Content-Type: application/json
//
// {"data":{"apps":{"apps":[ .. ],"totalNumApps":161,"__typename":"AppsOutput"}},"hasNext":true}
//
// --graphql
// Content-Type: application/json
//
// {"incremental":[{"data":{"groupAccessCount":0},"label":"test","path":["apps","apps",7],"hasNext":true}],"hasNext":true}

if opErr != nil {
w.WriteHeader(statusFor(opErr))

resp := exec.DispatchError(ctx, opErr)
writeJson(w, resp)
return
}

w.Header().Set(
"Content-Type",
fmt.Sprintf(`multipart/mixed;boundary="%s";deferSpec=20220824`, boundary),
)

responses, ctx := exec.DispatchOperation(ctx, rc)
initialResponse := true
for {
response := responses(ctx)
if response == nil {
break
}

fmt.Fprintf(w, "--%s\r\n", boundary)
fmt.Fprintf(w, "Content-Type: application/json\r\n\r\n")

if initialResponse {
writeJson(w, response)
initialResponse = false
} else {
writeIncrementalJson(w, response, response.HasNext)
}
fmt.Fprintf(w, "\r\n\r\n")
flusher.Flush()
}
}

func writeIncrementalJson(w io.Writer, response *graphql.Response, hasNext *bool) {
// TODO: Remove this wrapper on response once gqlgen supports the 2023 spec
b, err := json.Marshal(struct {
Incremental []graphql.Response `json:"incremental"`
HasNext *bool `json:"hasNext"`
}{
Incremental: []graphql.Response{*response},
HasNext: hasNext,
})
if err != nil {
panic(err)
}
w.Write(b)
}
168 changes: 168 additions & 0 deletions graphql/handler/transport/http_multipart_mixed_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package transport_test

import (
"bufio"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/99designs/gqlgen/graphql/handler/testserver"
"github.com/99designs/gqlgen/graphql/handler/transport"
)

func TestMultipartMixed(t *testing.T) {
initialize := func() *testserver.TestServer {
h := testserver.New()
h.AddTransport(transport.MultipartMixed{})
return h
}

initializeWithServer := func() (*testserver.TestServer, *httptest.Server) {
h := initialize()
return h, httptest.NewServer(h)
}

createHTTPRequest := func(url string, query string) *http.Request {
req, err := http.NewRequest("POST", url, strings.NewReader(query))
require.NoError(t, err, "Request threw error -> %s", err)
req.Header.Set("Accept", "multipart/mixed")
req.Header.Set("content-type", "application/json; charset=utf-8")
return req
}

doRequest := func(handler http.Handler, target, body string) *httptest.ResponseRecorder {
r := createHTTPRequest(target, body)
w := httptest.NewRecorder()

handler.ServeHTTP(w, r)
return w
}

t.Run("decode failure", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL, "notjson")
assert.Equal(t, http.StatusBadRequest, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"json request body could not be decoded: invalid character 'o' in literal null (expecting 'u') body:notjson"}],"data":null}`,
resp.Body.String(),
)
})

t.Run("parse failure", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL, `{"query": "!"}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"Unexpected !","locations":[{"line":1,"column":1}],"extensions":{"code":"GRAPHQL_PARSE_FAILED"}}],"data":null}`,
resp.Body.String(),
)
})

t.Run("validation failure", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL, `{"query": "{ title }"}`)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"Cannot query field \"title\" on type \"Query\".","locations":[{"line":1,"column":3}],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`,
resp.Body.String(),
)
})

t.Run("invalid variable", func(t *testing.T) {
handler, srv := initializeWithServer()
resp := doRequest(handler, srv.URL,
`{"query": "query($id:Int!){find(id:$id)}","variables":{"id":false}}`,
)
assert.Equal(t, http.StatusUnprocessableEntity, resp.Code, resp.Body.String())
assert.Equal(t, "application/json", resp.Header().Get("Content-Type"))
assert.Equal(
t,
`{"errors":[{"message":"cannot use bool as Int","path":["variable","id"],"extensions":{"code":"GRAPHQL_VALIDATION_FAILED"}}],"data":null}`,
resp.Body.String(),
)
})

readLine := func(br *bufio.Reader) string {
bs, err := br.ReadString('\n')
require.NoError(t, err)
return bs
}

t.Run("initial and incremental patches", func(t *testing.T) {
handler, srv := initializeWithServer()
defer srv.Close()

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
handler.SendNextSubscriptionMessage()
}()

client := &http.Client{}
req := createHTTPRequest(
srv.URL,
`{"query":"query { ... @defer { name } }"}`,
)
res, err := client.Do(req)
require.NoError(t, err, "Request threw error -> %s", err)
defer func() {
require.NoError(t, res.Body.Close())
}()

assert.Equal(t, 200, res.StatusCode, "Request return wrong status -> %d", res.Status)
assert.Equal(t, "keep-alive", res.Header.Get("Connection"))
assert.Contains(t, res.Header.Get("Content-Type"), "multipart/mixed")
assert.Contains(t, res.Header.Get("Content-Type"), `boundary="graphql"`)

br := bufio.NewReader(res.Body)

assert.Equal(t, "--graphql\r\n", readLine(br))
assert.Equal(t, "Content-Type: application/json\r\n", readLine(br))
assert.Equal(t, "\r\n", readLine(br))
assert.Equal(t,
"{\"data\":{\"name\":null},\"hasNext\":true}\r\n",
readLine(br),
)
assert.Equal(t, "\r\n", readLine(br))

wg.Add(1)
go func() {
defer wg.Done()
handler.SendNextSubscriptionMessage()
}()

assert.Equal(t, "--graphql\r\n", readLine(br))
assert.Equal(t, "Content-Type: application/json\r\n", readLine(br))
assert.Equal(t, "\r\n", readLine(br))
assert.Equal(
t,
"{\"incremental\":[{\"data\":{\"name\":\"test\"},\"hasNext\":false}],\"hasNext\":false}\r\n",
readLine(br),
)
assert.Equal(t, "\r\n", readLine(br))

wg.Add(1)
go func() {
defer wg.Done()
handler.SendCompleteSubscriptionMessage()
}()

_, err = br.ReadByte()
assert.Equal(t, err, io.EOF)

wg.Wait()
})
}
Loading