Skip to content

Commit

Permalink
Allow WebsocketInitFunc to add payload to Ack (#4)
Browse files Browse the repository at this point in the history
* Allow WebsocketInitFunc to add payload to Ack

The connection ACK message in the protocol for both
graphql-ws and graphql-transport-ws allows for a payload in the
connection ack message.

We really wanted to use this to establish better telemetry in our use of
websockets in graphql.

* Fix lint error in test

* Switch argument ordering.

---------

Co-authored-by: Chris Pride <[email protected]>
  • Loading branch information
telemenar and Chris Pride authored Sep 12, 2023
1 parent f90ac05 commit da137ea
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 22 deletions.
8 changes: 4 additions & 4 deletions _examples/websocket-initfunc/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import (
"github.com/rs/cors"
)

func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Get the token from payload
payload := initPayload["authToken"]
token, ok := payload.(string)
if !ok || token == "" {
return nil, errors.New("authToken not found in transport payload")
return nil, nil, errors.New("authToken not found in transport payload")
}

// Perform token verification and authentication...
Expand All @@ -32,7 +32,7 @@ func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (cont
// put it in context
ctxNew := context.WithValue(ctx, "username", userId)

return ctxNew, nil
return ctxNew, nil, nil
}

const defaultPort = "8080"
Expand Down Expand Up @@ -62,7 +62,7 @@ func main() {
return true
},
},
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return webSocketInit(ctx, initPayload)
},
})
Expand Down
16 changes: 13 additions & 3 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type (
initPayload InitPayload
}

WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error)
WebsocketErrorFunc func(ctx context.Context, err error)

// Callback called when websocket is closed.
Expand Down Expand Up @@ -179,8 +179,10 @@ func (c *wsConnection) init() bool {
}
}

var initAckPayload *InitPayload = nil
if c.InitFunc != nil {
ctx, err := c.InitFunc(c.ctx, c.initPayload)
var ctx context.Context
ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload)
if err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
Expand All @@ -189,7 +191,15 @@ func (c *wsConnection) init() bool {
c.ctx = ctx
}

c.write(&message{t: connectionAckMessageType})
if initAckPayload != nil {
initJsonAckPayload, err := json.Marshal(*initAckPayload)
if err != nil {
panic(err)
}
c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload})
} else {
c.write(&message{t: connectionAckMessageType})
}
c.write(&message{t: keepAliveMessageType})
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
Expand Down
57 changes: 42 additions & 15 deletions graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ func TestWebsocketInitFunc(t *testing.T) {
t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
},
})
srv := httptest.NewServer(h)
Expand All @@ -226,8 +226,8 @@ func TestWebsocketInitFunc(t *testing.T) {
t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return ctx, errors.New("invalid init payload")
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return ctx, nil, errors.New("invalid init payload")
},
})
srv := httptest.NewServer(h)
Expand Down Expand Up @@ -261,8 +261,8 @@ func TestWebsocketInitFunc(t *testing.T) {
h := handler.New(es)

h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
},
})

Expand All @@ -282,7 +282,7 @@ func TestWebsocketInitFunc(t *testing.T) {
h := testserver.New()
var cancel func()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
return
},
Expand All @@ -303,6 +303,33 @@ func TestWebsocketInitFunc(t *testing.T) {
assert.Equal(t, m.Type, connectionErrorMsg)
assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
})
t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
initResponsePayload := transport.InitPayload{"trackingId": "123-456"}
return context.WithValue(ctx, ckey("newkey"), "newvalue"), &initResponsePayload, nil
},
})
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))

connAck := readOp(c)
assert.Equal(t, connectionAckMsg, connAck.Type)

var payload map[string]interface{}
err := json.Unmarshal(connAck.Payload, &payload)
if err != nil {
t.Fatal("Unexpected Error", err)
}
assert.EqualValues(t, "123-456", payload["trackingId"])
assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
})
}

func TestWebSocketInitTimeout(t *testing.T) {
Expand Down Expand Up @@ -382,8 +409,8 @@ func TestWebSocketErrorFunc(t *testing.T) {
t.Run("init func errors do not call the error handler", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("this is not what we agreed upon")
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return ctx, nil, errors.New("this is not what we agreed upon")
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand All @@ -400,10 +427,10 @@ func TestWebSocketErrorFunc(t *testing.T) {
t.Run("init func context closes do not call the error handler", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
newCtx, cancel := context.WithCancel(ctx)
time.AfterFunc(time.Millisecond*5, cancel)
return newCtx, nil
return newCtx, nil, nil
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand All @@ -423,9 +450,9 @@ func TestWebSocketErrorFunc(t *testing.T) {
h := testserver.New()
var cancel func()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
return newCtx, nil
return newCtx, nil, nil
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand Down Expand Up @@ -477,8 +504,8 @@ func TestWebSocketCloseFunc(t *testing.T) {
h := testserver.New()
closeFuncCalled := make(chan bool, 1)
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("error during init")
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return ctx, nil, errors.New("error during init")
},
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
Expand Down

0 comments on commit da137ea

Please sign in to comment.