From 924a2c663ef5c7013e398efd7e2c4805fc0f2cc1 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 2 May 2024 12:52:34 -0400 Subject: [PATCH 01/10] feat: implement configurable ping timeout for `nhooyr.io/websocket` #83 - added `pingTimeout time.Duration` to `internal/nhooyr.io/websocket/conn.go`'s `Conn struct` - implemented `SetPingTimeout`, following how `nhooyr.io/websocket` exposes configurable conn values - if `conn`'s `pingTimeout` is nonpositive, then use `handleControl`'s 5 second timeout https://github.com/xmidt-org/xmidt-agent/blob/78dffe0cad394ab82f581940e3a8117f04941077/internal/nhooyr.io/websocket/read.go#L301 --- internal/nhooyr.io/websocket/conn.go | 2 + internal/nhooyr.io/websocket/read.go | 12 ++++ internal/nhooyr.io/websocket/write.go | 7 +++ internal/websocket/e2e_test.go | 82 +++++++++++++++++++++++++++ internal/websocket/ws.go | 11 +++- 5 files changed, 113 insertions(+), 1 deletion(-) diff --git a/internal/nhooyr.io/websocket/conn.go b/internal/nhooyr.io/websocket/conn.go index 9ca05ca..43cd29b 100644 --- a/internal/nhooyr.io/websocket/conn.go +++ b/internal/nhooyr.io/websocket/conn.go @@ -17,6 +17,7 @@ import ( "strconv" "sync" "sync/atomic" + "time" ) // MessageType represents the type of a WebSocket message. @@ -79,6 +80,7 @@ type Conn struct { closeErr error wroteClose bool + pingTimeout time.Duration pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} diff --git a/internal/nhooyr.io/websocket/read.go b/internal/nhooyr.io/websocket/read.go index 87ab7c0..4f668a7 100644 --- a/internal/nhooyr.io/websocket/read.go +++ b/internal/nhooyr.io/websocket/read.go @@ -95,6 +95,13 @@ func (c *Conn) SetPongListener(f func(context.Context, []byte)) { c.pongListener = f } +// SetPingTimeout sets the maximum time allowed between PINGs for the connection +// before the connection is closed. +// Nonpositive PingTimeout will default to handleControl's 5 second timeout. +func (c *Conn) SetPingTimeout(d time.Duration) { + c.pingTimeout = d +} + // SetReadLimit sets the max number of bytes to read for a single message. // It applies to the Reader and Read methods. // @@ -313,6 +320,11 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) { switch h.opcode { case opPing: + if c.pingTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, c.pingTimeout) + defer cancel() + } + c.pingListener(ctx, b) return c.writeControl(ctx, opPong, b) case opPong: diff --git a/internal/nhooyr.io/websocket/write.go b/internal/nhooyr.io/websocket/write.go index 6086082..0e0ad0d 100644 --- a/internal/nhooyr.io/websocket/write.go +++ b/internal/nhooyr.io/websocket/write.go @@ -236,6 +236,13 @@ func (mw *msgWriter) close() { func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) + switch opcode { + case opPong: + if c.pingTimeout > 0 { + ctx, cancel = context.WithTimeout(ctx, c.pingTimeout) + } + } + defer cancel() _, err := c.writeFrame(ctx, true, false, opcode, p) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index fe9e9b9..aa423ce 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -7,6 +7,7 @@ import ( "context" "errors" "fmt" + "net" "net/http" "net/http/httptest" "sync" @@ -358,3 +359,84 @@ func TestEndToEndConnectionIssues(t *testing.T) { assert.True(started) assert.True(msgCnt.Load() > 0, "got message") } + +func TestEndToEndPingTimeout(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + s := httptest.NewServer( + http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + require.NoError(err) + defer c.CloseNow() + + assert.Error(c.Ping(context.Background())) + })) + defer s.Close() + + var ( + connectCnt, disconnectCnt, heartbeatCnt atomic.Int64 + got *ws.Websocket + err error + disconnectErrs []error + ) + got, err = ws.New( + ws.URL(s.URL), + ws.DeviceID("mac:112233445566"), + ws.AddHeartbeatListener( + event.HeartbeatListenerFunc( + func(event.Heartbeat) { + heartbeatCnt.Add(1) + })), + ws.AddConnectListener( + event.ConnectListenerFunc( + func(event.Connect) { + connectCnt.Add(1) + })), + ws.AddDisconnectListener( + event.DisconnectListenerFunc( + func(e event.Disconnect) { + disconnectErrs = append(disconnectErrs, e.Err) + disconnectCnt.Add(1) + })), + ws.RetryPolicy(&retry.Config{ + Interval: time.Second, + Multiplier: 2.0, + Jitter: 1.0 / 3.0, + MaxInterval: 341*time.Second + 333*time.Millisecond, + }), + ws.WithIPv4(), + ws.NowFunc(time.Now), + ws.ConnectTimeout(30*time.Second), + ws.FetchURLTimeout(30*time.Second), + ws.MaxMessageBytes(256*1024), + ws.CredentialsDecorator(func(h http.Header) error { + return nil + }), + // Trigger a ping timeout + ws.PingTimeout(time.Nanosecond), + ) + require.NoError(err) + require.NotNil(got) + + got.Start() + time.Sleep(500 * time.Millisecond) + got.Stop() + // heartbeatCnt should be zero due to a ping timeout + assert.Equal(int64(0), heartbeatCnt.Load()) + assert.Greater(connectCnt.Load(), int64(0)) + assert.Greater(disconnectCnt.Load(), int64(0)) + // disconnectErrs should only contain + assert.NotEmpty(disconnectErrs) + // All disconnectErrs errors should be caused by context.DeadlineExceeded + for _, err := range disconnectErrs { + if errors.Is(err, net.ErrClosed) { + // net.ErrClosed may occur during testing, don't count them + continue + } + + assert.ErrorIs(err, context.DeadlineExceeded) + } + +} diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index ee3baf0..8f64d1d 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -233,7 +233,11 @@ func (ws *Websocket) run(ctx context.Context) { // Store the connection so writing can take place. ws.m.Lock() ws.conn = conn - ws.conn.SetPingListener((func(context.Context, []byte) { + ws.conn.SetPingListener((func(ctx context.Context, b []byte) { + if ctx.Err() != nil { + return + } + ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) { l.OnHeartbeat(event.Heartbeat{ At: ws.nowFunc(), @@ -242,6 +246,10 @@ func (ws *Websocket) run(ctx context.Context) { }) })) ws.conn.SetPongListener(func(ctx context.Context, b []byte) { + if ctx.Err() != nil { + return + } + ws.heartbeatListeners.Visit(func(l event.HeartbeatListener) { l.OnHeartbeat(event.Heartbeat{ At: ws.nowFunc(), @@ -334,6 +342,7 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R } conn.SetReadLimit(ws.maxMessageBytes) + conn.SetPingTimeout(ws.pingTimeout) return conn, resp, nil } From 61baf751fb988f9ee4d3f3b67c272947b1226366 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 2 May 2024 21:54:52 -0400 Subject: [PATCH 02/10] chore: typo --- internal/websocket/e2e_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index aa423ce..fcc196a 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -432,7 +432,7 @@ func TestEndToEndPingTimeout(t *testing.T) { // All disconnectErrs errors should be caused by context.DeadlineExceeded for _, err := range disconnectErrs { if errors.Is(err, net.ErrClosed) { - // net.ErrClosed may occur during testing, don't count them + // net.ErrClosed may occur during tests, don't count them continue } From d582cea8b86fa5388002835a47aabf2e40a521ab Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 2 May 2024 21:55:25 -0400 Subject: [PATCH 03/10] chore: typo --- internal/websocket/e2e_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index fcc196a..906424f 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -429,7 +429,7 @@ func TestEndToEndPingTimeout(t *testing.T) { assert.Greater(disconnectCnt.Load(), int64(0)) // disconnectErrs should only contain assert.NotEmpty(disconnectErrs) - // All disconnectErrs errors should be caused by context.DeadlineExceeded + // disconnectErrs should only contain context.DeadlineExceeded errors for _, err := range disconnectErrs { if errors.Is(err, net.ErrClosed) { // net.ErrClosed may occur during tests, don't count them From d570a296e4f0019c2710cc836ed42b7ef4677c49 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 2 May 2024 21:55:48 -0400 Subject: [PATCH 04/10] chore: typo --- internal/websocket/e2e_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 906424f..eed5fb5 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -427,7 +427,6 @@ func TestEndToEndPingTimeout(t *testing.T) { assert.Equal(int64(0), heartbeatCnt.Load()) assert.Greater(connectCnt.Load(), int64(0)) assert.Greater(disconnectCnt.Load(), int64(0)) - // disconnectErrs should only contain assert.NotEmpty(disconnectErrs) // disconnectErrs should only contain context.DeadlineExceeded errors for _, err := range disconnectErrs { From 4280309559814e058ce5be81851ba3ce9e7a9544 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 2 May 2024 21:57:30 -0400 Subject: [PATCH 05/10] chore: typo --- internal/websocket/e2e_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index eed5fb5..1e7d5f9 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -414,7 +414,7 @@ func TestEndToEndPingTimeout(t *testing.T) { ws.CredentialsDecorator(func(h http.Header) error { return nil }), - // Trigger a ping timeout + // Triggers ping timeouts ws.PingTimeout(time.Nanosecond), ) require.NoError(err) From a37fcd7426e9fb688eeaafe3b2a4ba15fb3b2f9b Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 2 May 2024 21:57:42 -0400 Subject: [PATCH 06/10] chore: typo --- internal/websocket/e2e_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 1e7d5f9..829da27 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -423,7 +423,7 @@ func TestEndToEndPingTimeout(t *testing.T) { got.Start() time.Sleep(500 * time.Millisecond) got.Stop() - // heartbeatCnt should be zero due to a ping timeout + // heartbeatCnt should be zero due ping timeouts assert.Equal(int64(0), heartbeatCnt.Load()) assert.Greater(connectCnt.Load(), int64(0)) assert.Greater(disconnectCnt.Load(), int64(0)) From 068f6a6bf2e73150c826192b8896d3170ca24c26 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Mon, 6 May 2024 13:50:13 -0400 Subject: [PATCH 07/10] chore: simplify `nhooyr.io/websocket/write.go`'s `writeControl` --- internal/nhooyr.io/websocket/close.go | 4 +++- internal/nhooyr.io/websocket/write.go | 11 ----------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/internal/nhooyr.io/websocket/close.go b/internal/nhooyr.io/websocket/close.go index 6e62ab6..d19b38a 100644 --- a/internal/nhooyr.io/websocket/close.go +++ b/internal/nhooyr.io/websocket/close.go @@ -156,8 +156,10 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { if ce.Code != StatusNoStatusRcvd { p, marshalErr = ce.bytes() } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() - writeErr := c.writeControl(context.Background(), opClose, p) + writeErr := c.writeControl(ctx, opClose, p) if CloseStatus(writeErr) != -1 { // Not a real error if it's due to a close frame being received. writeErr = nil diff --git a/internal/nhooyr.io/websocket/write.go b/internal/nhooyr.io/websocket/write.go index 0e0ad0d..b96aefd 100644 --- a/internal/nhooyr.io/websocket/write.go +++ b/internal/nhooyr.io/websocket/write.go @@ -15,7 +15,6 @@ import ( "fmt" "io" "net" - "time" "compress/flate" @@ -235,16 +234,6 @@ func (mw *msgWriter) close() { } func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - switch opcode { - case opPong: - if c.pingTimeout > 0 { - ctx, cancel = context.WithTimeout(ctx, c.pingTimeout) - } - } - - defer cancel() - _, err := c.writeFrame(ctx, true, false, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) From ec0b16ff092149f1b5c8b127aae2ca1cf18f3dbf Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Mon, 6 May 2024 13:53:18 -0400 Subject: [PATCH 08/10] chore: formatting --- internal/nhooyr.io/websocket/close.go | 1 + 1 file changed, 1 insertion(+) diff --git a/internal/nhooyr.io/websocket/close.go b/internal/nhooyr.io/websocket/close.go index d19b38a..26f7cf4 100644 --- a/internal/nhooyr.io/websocket/close.go +++ b/internal/nhooyr.io/websocket/close.go @@ -156,6 +156,7 @@ func (c *Conn) writeClose(code StatusCode, reason string) error { if ce.Code != StatusNoStatusRcvd { p, marshalErr = ce.bytes() } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() From e550cdf429f92937ee861c28ea774ea471dd967f Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Mon, 13 May 2024 16:09:12 -0400 Subject: [PATCH 09/10] patch, update pingtimeout e2e test --- internal/websocket/e2e_test.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index d250c13..b5031a5 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -426,6 +426,9 @@ func TestEndToEndPingTimeout(t *testing.T) { ws.CredentialsDecorator(func(h http.Header) error { return nil }), + ws.ConveyDecorator(func(h http.Header) error { + return nil + }), // Triggers ping timeouts ws.PingTimeout(time.Nanosecond), ) From 0b5cf47575c40e772a97f817d49c78e29d29d65e Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Thu, 16 May 2024 15:37:32 -0400 Subject: [PATCH 10/10] chore: fix tests --- internal/websocket/e2e_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 836d64c..648e79d 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -420,7 +420,6 @@ func TestEndToEndPingTimeout(t *testing.T) { }), ws.WithIPv4(), ws.NowFunc(time.Now), - ws.ConnectTimeout(30*time.Second), ws.FetchURLTimeout(30*time.Second), ws.MaxMessageBytes(256*1024), ws.CredentialsDecorator(func(h http.Header) error { @@ -431,6 +430,7 @@ func TestEndToEndPingTimeout(t *testing.T) { }), // Triggers ping timeouts ws.PingTimeout(time.Nanosecond), + ws.HTTPClient(nil), ) require.NoError(err) require.NotNil(got)