diff --git a/graphql/handler.go b/graphql/handler.go index 4df36117b8..9f03c01dc8 100644 --- a/graphql/handler.go +++ b/graphql/handler.go @@ -93,6 +93,7 @@ type ( Transport interface { Supports(r *http.Request) bool Do(w http.ResponseWriter, r *http.Request, exec GraphExecutor) + String() string } ) diff --git a/graphql/handler/server.go b/graphql/handler/server.go index 644bad8d99..dfcd3931f0 100644 --- a/graphql/handler/server.go +++ b/graphql/handler/server.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net/http" + "slices" "time" "github.com/vektah/gqlparser/v2/ast" @@ -39,6 +40,8 @@ func NewDefaultServer(es graphql.ExecutableSchema) *Server { }) srv.AddTransport(transport.Options{}) srv.AddTransport(transport.GET{}) + // To enable Server-Sent-Events, must insert here: + // srv.AddTransport(transport.SSE{}) srv.AddTransport(transport.POST{}) srv.AddTransport(transport.MultipartForm{}) @@ -56,6 +59,22 @@ func (s *Server) AddTransport(transport graphql.Transport) { s.transports = append(s.transports, transport) } +func (s *Server) PrependTransport(transport graphql.Transport) { + s.transports = append([]graphql.Transport{transport}, s.transports...) +} + +// AddPrePostTransport is primarily for adding experimental transports like +// SSE so that the usual POST transport doesn't respond +func (s *Server) AddPrePostTransport(transport graphql.Transport) { + postIndex := slices.IndexFunc(s.transports, func(n graphql.Transport) bool { + return n.String() == "POST" + }) + + if postIndex != -1 { + s.transports = slices.Insert(s.transports, postIndex, transport) + } +} + func (s *Server) SetErrorPresenter(f graphql.ErrorPresenterFunc) { s.exec.SetErrorPresenter(f) } @@ -96,6 +115,7 @@ func (s *Server) AroundResponses(f graphql.ResponseMiddleware) { s.exec.AroundResponses(f) } +// Transport choice is first acceptable func (s *Server) getTransport(r *http.Request) graphql.Transport { for _, t := range s.transports { if t.Supports(r) { @@ -119,13 +139,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { r = r.WithContext(graphql.StartOperationTrace(r.Context())) - transport := s.getTransport(r) - if transport == nil { + gqlTransport := s.getTransport(r) + if gqlTransport == nil { sendErrorf(w, http.StatusBadRequest, "transport not supported") return } - transport.Do(w, r, s.exec) + gqlTransport.Do(w, r, s.exec) } func sendError(w http.ResponseWriter, code int, errors ...*gqlerror.Error) { diff --git a/graphql/handler/server_test.go b/graphql/handler/server_test.go index 09d117db70..b45872c05d 100644 --- a/graphql/handler/server_test.go +++ b/graphql/handler/server_test.go @@ -166,6 +166,12 @@ func TestErrorServer(t *testing.T) { type panicTransport struct{} +var _ graphql.Transport = panicTransport{} + +func (t panicTransport) String() string { + return "panicTransport" +} + func (t panicTransport) Supports(r *http.Request) bool { return true } diff --git a/graphql/handler/transport/http_form_multipart.go b/graphql/handler/transport/http_form_multipart.go index b9eb5f8f43..9f0bf30c77 100644 --- a/graphql/handler/transport/http_form_multipart.go +++ b/graphql/handler/transport/http_form_multipart.go @@ -28,6 +28,10 @@ type MultipartForm struct { var _ graphql.Transport = MultipartForm{} +func (f MultipartForm) String() string { + return "MultipartForm" +} + func (f MultipartForm) Supports(r *http.Request) bool { if r.Header.Get("Upgrade") != "" { return false diff --git a/graphql/handler/transport/http_form_urlencoded.go b/graphql/handler/transport/http_form_urlencoded.go index f877c2dd26..28d4e42b6e 100644 --- a/graphql/handler/transport/http_form_urlencoded.go +++ b/graphql/handler/transport/http_form_urlencoded.go @@ -21,6 +21,10 @@ type UrlEncodedForm struct { var _ graphql.Transport = UrlEncodedForm{} +func (h UrlEncodedForm) String() string { + return "UrlEncodedForm" +} + func (h UrlEncodedForm) Supports(r *http.Request) bool { if r.Header.Get("Upgrade") != "" { return false diff --git a/graphql/handler/transport/http_get.go b/graphql/handler/transport/http_get.go index 470a0fbec2..436b059d45 100644 --- a/graphql/handler/transport/http_get.go +++ b/graphql/handler/transport/http_get.go @@ -24,6 +24,10 @@ type GET struct { var _ graphql.Transport = GET{} +func (h GET) String() string { + return "GET" +} + func (h GET) Supports(r *http.Request) bool { if r.Header.Get("Upgrade") != "" { return false diff --git a/graphql/handler/transport/http_graphql.go b/graphql/handler/transport/http_graphql.go index 0bad1110de..889744292f 100644 --- a/graphql/handler/transport/http_graphql.go +++ b/graphql/handler/transport/http_graphql.go @@ -23,6 +23,10 @@ type GRAPHQL struct { var _ graphql.Transport = GRAPHQL{} +func (h GRAPHQL) String() string { + return "GRAPHQL" +} + func (h GRAPHQL) Supports(r *http.Request) bool { if r.Header.Get("Upgrade") != "" { return false diff --git a/graphql/handler/transport/http_post.go b/graphql/handler/transport/http_post.go index 985f8db294..38c53b3601 100644 --- a/graphql/handler/transport/http_post.go +++ b/graphql/handler/transport/http_post.go @@ -22,6 +22,10 @@ type POST struct { var _ graphql.Transport = POST{} +func (h POST) String() string { + return "POST" +} + func (h POST) Supports(r *http.Request) bool { if r.Header.Get("Upgrade") != "" { return false diff --git a/graphql/handler/transport/options.go b/graphql/handler/transport/options.go index 5d7f4b881a..d01f94b64e 100644 --- a/graphql/handler/transport/options.go +++ b/graphql/handler/transport/options.go @@ -15,6 +15,10 @@ type Options struct { var _ graphql.Transport = Options{} +func (o Options) String() string { + return "Options" +} + func (o Options) Supports(r *http.Request) bool { return r.Method == "HEAD" || r.Method == "OPTIONS" } diff --git a/graphql/handler/transport/sse.go b/graphql/handler/transport/sse.go index 1d59fdffe5..1502e9c9f6 100644 --- a/graphql/handler/transport/sse.go +++ b/graphql/handler/transport/sse.go @@ -18,6 +18,10 @@ type SSE struct{} var _ graphql.Transport = SSE{} +func (t SSE) String() string { + return "SSE" +} + func (t SSE) Supports(r *http.Request) bool { if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") { return false diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index 32e31c7c75..4be716e20e 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -86,6 +86,10 @@ var ( _ error = WebsocketError{} ) +func (t Websocket) String() string { + return "Websocket" +} + func (t Websocket) Supports(r *http.Request) bool { return r.Header.Get("Upgrade") != "" }