Skip to content

Commit

Permalink
feat: implement configurable ping timeout for nhooyr.io/websocket #88
Browse files Browse the repository at this point in the history
- added `pingTimeout   time.Duration` to `internal/nhooyr.io/websocket/conn.go`'s `Conn struct`
- implemented `SetPingTimeout`, following how `nhooyr.io/websocket` exposes configurable conn values
- if `conn`'s `pingTimeout` is nonpositive, then use `handleControl`'s 5 second timeout https://github.com/xmidt-org/xmidt-agent/blob/78dffe0cad394ab82f581940e3a8117f04941077/internal/nhooyr.io/websocket/read.go#L301
  • Loading branch information
denopink committed May 2, 2024
1 parent 78dffe0 commit 67a106f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
2 changes: 2 additions & 0 deletions internal/nhooyr.io/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"sync"
"sync/atomic"
"time"
)

// MessageType represents the type of a WebSocket message.
Expand Down Expand Up @@ -79,6 +80,7 @@ type Conn struct {
closeErr error
wroteClose bool

pingTimeout time.Duration
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
Expand Down
12 changes: 12 additions & 0 deletions internal/nhooyr.io/websocket/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ func (c *Conn) SetPongListener(f func(context.Context, []byte)) {
c.pongListener = f
}

// SetPingTimeout sets the maximum time allowed between PINGs for the connection
// before the connection is closed.
// Nonpositive PingTimeout will default to handleControl's 5 second timeout.
func (c *Conn) SetPingTimeout(d time.Duration) {
c.pingTimeout = d
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
Expand Down Expand Up @@ -313,6 +320,11 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {

switch h.opcode {
case opPing:
if c.pingTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.pingTimeout)
defer cancel()
}

c.pingListener(ctx, b)
return c.writeControl(ctx, opPong, b)
case opPong:
Expand Down
11 changes: 10 additions & 1 deletion internal/websocket/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,11 @@ func (ws *Websocket) run(ctx context.Context) {
// Store the connection so writing can take place.
ws.m.Lock()
ws.conn = conn
ws.conn.SetPingListener((func(context.Context, []byte) {
ws.conn.SetPingListener((func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
}

ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) {
l.OnHeartbeat(event.Heartbeat{
At: ws.nowFunc(),
Expand All @@ -242,6 +246,10 @@ func (ws *Websocket) run(ctx context.Context) {
})
}))
ws.conn.SetPongListener(func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
}

ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) {
l.OnHeartbeat(event.Heartbeat{
At: ws.nowFunc(),
Expand Down Expand Up @@ -334,6 +342,7 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R
}

conn.SetReadLimit(ws.maxMessageBytes)
conn.SetPingTimeout(ws.pingTimeout)
return conn, resp, nil
}

Expand Down

0 comments on commit 67a106f

Please sign in to comment.