From 7ebde41dbfce79163627037759372513a550a215 Mon Sep 17 00:00:00 2001 From: schmidtw Date: Thu, 4 Apr 2024 19:13:04 -0700 Subject: [PATCH] Add the ability to listen to when pings are recieved and pongs are sent. --- internal/nhooyr.io/websocket/conn.go | 5 +++ internal/nhooyr.io/websocket/conn_test.go | 43 +++++++++++++++++++++++ internal/nhooyr.io/websocket/read.go | 18 ++++++++++ 3 files changed, 66 insertions(+) diff --git a/internal/nhooyr.io/websocket/conn.go b/internal/nhooyr.io/websocket/conn.go index cf7063e..9ca05ca 100644 --- a/internal/nhooyr.io/websocket/conn.go +++ b/internal/nhooyr.io/websocket/conn.go @@ -82,6 +82,8 @@ type Conn struct { pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} + pingListener func(context.Context, []byte) + pongListener func(context.Context, []byte) } type connConfig struct { @@ -112,6 +114,9 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } + // set default ping, pong handler + c.SetPingListener(nil) + c.SetPongListener(nil) c.readMu = newMu(c) c.writeFrameMu = newMu(c) diff --git a/internal/nhooyr.io/websocket/conn_test.go b/internal/nhooyr.io/websocket/conn_test.go index 7037e79..e60f2ef 100644 --- a/internal/nhooyr.io/websocket/conn_test.go +++ b/internal/nhooyr.io/websocket/conn_test.go @@ -96,6 +96,49 @@ func TestConn(t *testing.T) { assert.Contains(t, err, "failed to wait for pong") }) + t.Run("pingHandler", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + //defer tt.Cleanup() + + var count int + c2.SetPingListener(func(context.Context, []byte) { + count++ + }) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + + for i := 0; i < 10; i++ { + err := c1.Ping(tt.ctx) + assert.Success(t, err) + } + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + assert.Equal(t, "count", 10, count) + }) + + t.Run("pongHandler", func(t *testing.T) { + tt, c1, c2 := newConnTest(t, nil, nil) + //defer tt.t.Cleanup() + + var count int + c1.SetPongListener(func(context.Context, []byte) { + count++ + }) + + c1.CloseRead(tt.ctx) + c2.CloseRead(tt.ctx) + for i := 0; i < 10; i++ { + err := c1.Ping(tt.ctx) + assert.Success(t, err) + } + + err := c1.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) + assert.Equal(t, "count", 10, count) + }) + t.Run("concurrentWrite", func(t *testing.T) { tt, c1, c2 := newConnTest(t, nil, nil) diff --git a/internal/nhooyr.io/websocket/read.go b/internal/nhooyr.io/websocket/read.go index 3a374b5..87ab7c0 100644 --- a/internal/nhooyr.io/websocket/read.go +++ b/internal/nhooyr.io/websocket/read.go @@ -79,6 +79,22 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { return ctx } +// SetPingListener calls the provided function when a ping is received. +func (c *Conn) SetPingListener(f func(context.Context, []byte)) { + if f == nil { + f = func(context.Context, []byte) {} + } + c.pingListener = f +} + +// SetPongListener calls the provided function when a pong is sent. +func (c *Conn) SetPongListener(f func(context.Context, []byte)) { + if f == nil { + f = func(context.Context, []byte) {} + } + c.pongListener = f +} + // SetReadLimit sets the max number of bytes to read for a single message. // It applies to the Reader and Read methods. // @@ -297,8 +313,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: + c.pingListener(ctx, b) return c.writeControl(ctx, opPong, b) case opPong: + c.pongListener(ctx, b) c.activePingsMu.Lock() pong, ok := c.activePings[string(b)] c.activePingsMu.Unlock()