Skip to content

Commit

Permalink
Merge pull request #113 from xmidt-org/denopink/feat/configurable-pin…
Browse files Browse the repository at this point in the history
…g-timeout

feat: implement configurable ping timeout for `nhooyr.io/websocket`
  • Loading branch information
denopink authored May 16, 2024
2 parents 42ebea8 + b8e71b6 commit 232dd3a
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 6 deletions.
5 changes: 4 additions & 1 deletion internal/nhooyr.io/websocket/close.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,10 @@ func (c *Conn) writeClose(code StatusCode, reason string) error {
p, marshalErr = ce.bytes()
}

writeErr := c.writeControl(context.Background(), opClose, p)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()

writeErr := c.writeControl(ctx, opClose, p)
if CloseStatus(writeErr) != -1 {
// Not a real error if it's due to a close frame being received.
writeErr = nil
Expand Down
2 changes: 2 additions & 0 deletions internal/nhooyr.io/websocket/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strconv"
"sync"
"sync/atomic"
"time"
)

// MessageType represents the type of a WebSocket message.
Expand Down Expand Up @@ -79,6 +80,7 @@ type Conn struct {
closeErr error
wroteClose bool

pingTimeout time.Duration
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
Expand Down
12 changes: 12 additions & 0 deletions internal/nhooyr.io/websocket/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ func (c *Conn) SetPongListener(f func(context.Context, []byte)) {
c.pongListener = f
}

// SetPingTimeout sets the maximum time allowed between PINGs for the connection
// before the connection is closed.
// Nonpositive PingTimeout will default to handleControl's 5 second timeout.
func (c *Conn) SetPingTimeout(d time.Duration) {
c.pingTimeout = d
}

// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
Expand Down Expand Up @@ -313,6 +320,11 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {

switch h.opcode {
case opPing:
if c.pingTimeout > 0 {
ctx, cancel = context.WithTimeout(ctx, c.pingTimeout)
defer cancel()
}

c.pingListener(ctx, b)
return c.writeControl(ctx, opPong, b)
case opPong:
Expand Down
4 changes: 0 additions & 4 deletions internal/nhooyr.io/websocket/write.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"fmt"
"io"
"net"
"time"

"compress/flate"

Expand Down Expand Up @@ -235,9 +234,6 @@ func (mw *msgWriter) close() {
}

func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()

_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
Expand Down
84 changes: 84 additions & 0 deletions internal/websocket/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/http/httptest"
"sync"
Expand Down Expand Up @@ -370,3 +371,86 @@ func TestEndToEndConnectionIssues(t *testing.T) {
assert.True(started)
assert.True(msgCnt.Load() > 0, "got message")
}

func TestEndToEndPingTimeout(t *testing.T) {
assert := assert.New(t)
require := require.New(t)

s := httptest.NewServer(
http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
require.NoError(err)
defer c.CloseNow()

assert.Error(c.Ping(context.Background()))
}))
defer s.Close()

var (
connectCnt, disconnectCnt, heartbeatCnt atomic.Int64
got *ws.Websocket
err error
disconnectErrs []error
)
got, err = ws.New(
ws.URL(s.URL),
ws.DeviceID("mac:112233445566"),
ws.AddHeartbeatListener(
event.HeartbeatListenerFunc(
func(event.Heartbeat) {
heartbeatCnt.Add(1)
})),
ws.AddConnectListener(
event.ConnectListenerFunc(
func(event.Connect) {
connectCnt.Add(1)
})),
ws.AddDisconnectListener(
event.DisconnectListenerFunc(
func(e event.Disconnect) {
disconnectErrs = append(disconnectErrs, e.Err)
disconnectCnt.Add(1)
})),
ws.RetryPolicy(&retry.Config{
Interval: time.Second,
Multiplier: 2.0,
Jitter: 1.0 / 3.0,
MaxInterval: 341*time.Second + 333*time.Millisecond,
}),
ws.WithIPv4(),
ws.NowFunc(time.Now),
ws.FetchURLTimeout(30*time.Second),
ws.MaxMessageBytes(256*1024),
ws.CredentialsDecorator(func(h http.Header) error {
return nil
}),
ws.ConveyDecorator(func(h http.Header) error {
return nil
}),
// Triggers ping timeouts
ws.PingTimeout(time.Nanosecond),
ws.HTTPClient(nil),
)
require.NoError(err)
require.NotNil(got)

got.Start()
time.Sleep(500 * time.Millisecond)
got.Stop()
// heartbeatCnt should be zero due ping timeouts
assert.Equal(int64(0), heartbeatCnt.Load())
assert.Greater(connectCnt.Load(), int64(0))
assert.Greater(disconnectCnt.Load(), int64(0))
assert.NotEmpty(disconnectErrs)
// disconnectErrs should only contain context.DeadlineExceeded errors
for _, err := range disconnectErrs {
if errors.Is(err, net.ErrClosed) {
// net.ErrClosed may occur during tests, don't count them
continue
}

assert.ErrorIs(err, context.DeadlineExceeded)
}

}
11 changes: 10 additions & 1 deletion internal/websocket/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,11 @@ func (ws *Websocket) run(ctx context.Context) {
// Store the connection so writing can take place.
ws.m.Lock()
ws.conn = conn
ws.conn.SetPingListener((func(context.Context, []byte) {
ws.conn.SetPingListener((func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
}

ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) {
l.OnHeartbeat(event.Heartbeat{
At: ws.nowFunc(),
Expand All @@ -256,6 +260,10 @@ func (ws *Websocket) run(ctx context.Context) {
})
}))
ws.conn.SetPongListener(func(ctx context.Context, b []byte) {
if ctx.Err() != nil {
return
}

ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) {
l.OnHeartbeat(event.Heartbeat{
At: ws.nowFunc(),
Expand Down Expand Up @@ -346,6 +354,7 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R
}

conn.SetReadLimit(ws.maxMessageBytes)
conn.SetPingTimeout(ws.pingTimeout)
return conn, resp, nil
}

Expand Down

0 comments on commit 232dd3a

Please sign in to comment.