diff --git a/_examples/websocket-initfunc/server/server.go b/_examples/websocket-initfunc/server/server.go index 4704708dab0..64c3010f81c 100644 --- a/_examples/websocket-initfunc/server/server.go +++ b/_examples/websocket-initfunc/server/server.go @@ -18,12 +18,12 @@ import ( "github.com/rs/cors" ) -func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { +func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { // Get the token from payload payload := initPayload["authToken"] token, ok := payload.(string) if !ok || token == "" { - return nil, errors.New("authToken not found in transport payload") + return nil, nil, errors.New("authToken not found in transport payload") } // Perform token verification and authentication... @@ -32,7 +32,7 @@ func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (cont // put it in context ctxNew := context.WithValue(ctx, "username", userId) - return ctxNew, nil + return ctxNew, nil, nil } const defaultPort = "8080" @@ -62,7 +62,7 @@ func main() { return true }, }, - InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) { return webSocketInit(ctx, initPayload) }, }) diff --git a/graphql/handler/transport/websocket.go b/graphql/handler/transport/websocket.go index acd124fe8e5..ed1d9588c9c 100644 --- a/graphql/handler/transport/websocket.go +++ b/graphql/handler/transport/websocket.go @@ -44,7 +44,7 @@ type ( initPayload InitPayload } - WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error) + WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error) WebsocketErrorFunc func(ctx context.Context, err error) // Callback called when websocket is closed. @@ -179,8 +179,10 @@ func (c *wsConnection) init() bool { } } + var initAckPayload *InitPayload = nil if c.InitFunc != nil { - ctx, err := c.InitFunc(c.ctx, c.initPayload) + var ctx context.Context + ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload) if err != nil { c.sendConnectionError(err.Error()) c.close(websocket.CloseNormalClosure, "terminated") @@ -189,7 +191,15 @@ func (c *wsConnection) init() bool { c.ctx = ctx } - c.write(&message{t: connectionAckMessageType}) + if initAckPayload != nil { + initJsonAckPayload, err := json.Marshal(*initAckPayload) + if err != nil { + panic(err) + } + c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload}) + } else { + c.write(&message{t: connectionAckMessageType}) + } c.write(&message{t: keepAliveMessageType}) case connectionCloseMessageType: c.close(websocket.CloseNormalClosure, "terminated") diff --git a/graphql/handler/transport/websocket_test.go b/graphql/handler/transport/websocket_test.go index 9ac8ad65cc8..678bfeab25c 100644 --- a/graphql/handler/transport/websocket_test.go +++ b/graphql/handler/transport/websocket_test.go @@ -207,8 +207,8 @@ func TestWebsocketInitFunc(t *testing.T) { t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { h := testserver.New() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { - return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil }, }) srv := httptest.NewServer(h) @@ -226,8 +226,8 @@ func TestWebsocketInitFunc(t *testing.T) { t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { h := testserver.New() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { - return ctx, errors.New("invalid init payload") + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + return ctx, nil, errors.New("invalid init payload") }, }) srv := httptest.NewServer(h) @@ -261,8 +261,8 @@ func TestWebsocketInitFunc(t *testing.T) { h := handler.New(es) h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) { - return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil }, }) @@ -282,7 +282,7 @@ func TestWebsocketInitFunc(t *testing.T) { h := testserver.New() var cancel func() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) { + InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) { newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5) return }, @@ -303,6 +303,33 @@ func TestWebsocketInitFunc(t *testing.T) { assert.Equal(t, m.Type, connectionErrorMsg) assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`) }) + t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) { + h := testserver.New() + h.AddTransport(transport.Websocket{ + InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) { + initResponsePayload := transport.InitPayload{"trackingId": "123-456"} + return context.WithValue(ctx, ckey("newkey"), "newvalue"), &initResponsePayload, nil + }, + }) + srv := httptest.NewServer(h) + defer srv.Close() + + c := wsConnect(srv.URL) + defer c.Close() + + require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg})) + + connAck := readOp(c) + assert.Equal(t, connectionAckMsg, connAck.Type) + + var payload map[string]interface{} + err := json.Unmarshal(connAck.Payload, &payload) + if err != nil { + t.Fatal("Unexpected Error", err) + } + assert.EqualValues(t, "123-456", payload["trackingId"]) + assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type) + }) } func TestWebSocketInitTimeout(t *testing.T) { @@ -382,8 +409,8 @@ func TestWebSocketErrorFunc(t *testing.T) { t.Run("init func errors do not call the error handler", func(t *testing.T) { h := testserver.New() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { - return ctx, errors.New("this is not what we agreed upon") + InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) { + return ctx, nil, errors.New("this is not what we agreed upon") }, ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) @@ -400,10 +427,10 @@ func TestWebSocketErrorFunc(t *testing.T) { t.Run("init func context closes do not call the error handler", func(t *testing.T) { h := testserver.New() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { + InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) { newCtx, cancel := context.WithCancel(ctx) time.AfterFunc(time.Millisecond*5, cancel) - return newCtx, nil + return newCtx, nil, nil }, ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) @@ -423,9 +450,9 @@ func TestWebSocketErrorFunc(t *testing.T) { h := testserver.New() var cancel func() h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) { + InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) { newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5)) - return newCtx, nil + return newCtx, nil, nil }, ErrorFunc: func(_ context.Context, err error) { assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error()) @@ -477,8 +504,8 @@ func TestWebSocketCloseFunc(t *testing.T) { h := testserver.New() closeFuncCalled := make(chan bool, 1) h.AddTransport(transport.Websocket{ - InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) { - return ctx, errors.New("error during init") + InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) { + return ctx, nil, errors.New("error during init") }, CloseFunc: func(_ context.Context, _closeCode int) { closeFuncCalled <- true