From 54bb265b488450f2c41f1070e15784fb9d245258 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Tue, 2 Apr 2024 13:07:41 -0400 Subject: [PATCH 1/5] feat: improve ws handling --- cmd/xmidt-agent/credentials.go | 2 + internal/websocket/e2e_test.go | 16 ++- internal/websocket/ws.go | 234 +++++++++++++++++++-------------- 3 files changed, 148 insertions(+), 104 deletions(-) diff --git a/cmd/xmidt-agent/credentials.go b/cmd/xmidt-agent/credentials.go index eab9e63..2fca52c 100644 --- a/cmd/xmidt-agent/credentials.go +++ b/cmd/xmidt-agent/credentials.go @@ -43,6 +43,8 @@ func provideCredentials(in credsIn) (*credentials.Credentials, error) { opts := []credentials.Option{ credentials.URL(in.Creds.URL), + // enabling `Required` allows the xmidt-agent to send connect events for auth related errors + credentials.Required(), credentials.HTTPClient(client), credentials.MacAddress(in.ID.DeviceID), credentials.SerialNumber(in.ID.SerialNumber), diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 2467cc7..e91af1c 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -23,6 +23,8 @@ import ( ) func TestEndToEnd(t *testing.T) { + var finished bool + assert := assert.New(t) require := require.New(t) @@ -44,6 +46,11 @@ func TestEndToEnd(t *testing.T) { require.NoError(err) mt, got, err := c.Read(ctx) + // server will halt until the websocket closes resulting in a EOF + if finished { + return + } + require.NoError(err) require.Equal(websocket.MessageBinary, mt) require.NotEmpty(got) @@ -51,7 +58,7 @@ func TestEndToEnd(t *testing.T) { err = wrp.NewDecoderBytes(got, wrp.Msgpack).Decode(&msg) require.NoError(err) require.Equal(wrp.SimpleEventMessageType, msg.Type) - require.Equal("server", msg.Source) + require.Equal("client", msg.Source) c.Close(websocket.StatusNormalClosure, "") })) @@ -103,7 +110,7 @@ func TestEndToEnd(t *testing.T) { // Allow multiple calls to start. got.Start() - ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 2000*time.Millisecond) defer cancel() for { @@ -136,8 +143,9 @@ func TestEndToEnd(t *testing.T) { } time.Sleep(10 * time.Millisecond) } - time.Sleep(10 * time.Millisecond) + finished = true got.Stop() + time.Sleep(10 * time.Millisecond) } func TestEndToEndBadData(t *testing.T) { @@ -256,7 +264,7 @@ func TestEndToEndConnectionIssues(t *testing.T) { require.NoError(err) defer c.CloseNow() - ctx, cancel := context.WithTimeout(r.Context(), 2000000*time.Millisecond) + ctx, cancel := context.WithTimeout(r.Context(), 200*time.Millisecond) defer cancel() msg := wrp.Message{ diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index f2418e9..727a33c 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -6,7 +6,7 @@ package websocket import ( "context" "errors" - "fmt" + "io" "net" "net/http" "sync" @@ -101,7 +101,10 @@ type Websocket struct { wg sync.WaitGroup shutdown context.CancelFunc - conn *nhws.Conn + policy retry.Policy + decoder wrp.Decoder + encoder wrp.Encoder + conn *nhws.Conn } // Option is a functional option type for WS. @@ -152,13 +155,18 @@ func (ws *Websocket) Start() { var ctx context.Context ctx, ws.shutdown = context.WithCancel(context.Background()) + // Init retry policy, but it'll be reset on recurring successful connections. + ws.policy = ws.retryPolicyFactory.NewPolicy(ctx) + ws.decoder = wrp.NewDecoder(nil, wrp.Msgpack) - go ws.run(ctx) + go ws.readPump(ctx) } // Stop stops the websocket connection. func (ws *Websocket) Stop() { ws.m.Lock() + // Avoid the overhead of the close handshake. + ws.conn.Close(nhws.StatusNormalClosure, "") shutdown := ws.shutdown ws.m.Unlock() @@ -183,101 +191,32 @@ func (ws *Websocket) Send(ctx context.Context, msg wrp.Message) error { return err } -func (ws *Websocket) run(ctx context.Context) { +func (ws *Websocket) readPump(ctx context.Context) { ws.wg.Add(1) defer ws.wg.Done() - decoder := wrp.NewDecoder(nil, wrp.Msgpack) - encoder := wrp.NewEncoder(nil, wrp.Msgpack) mode := ws.nextMode(ipv4) - - policy := ws.retryPolicyFactory.NewPolicy(ctx) - + reconnect := true for { - var next time.Duration - - mode = ws.nextMode(mode) - cEvent := event.Connect{ - Started: ws.nowFunc(), - Mode: mode.ToEvent(), + var dialErr error + if reconnect { + dialErr = ws.dial(ctx, mode) } - // If auth fails, then continue with openfail xmidt connection - ws.credDecorator(ws.additionalHeaders) - - conn, _, dialErr := ws.dial(ctx, mode) //nolint:bodyclose - cEvent.At = ws.nowFunc() - if dialErr == nil { - ws.connectListeners.Visit(func(l event.ConnectListener) { - l.OnConnect(cEvent) - }) - - // Reset the retry policy on a successful connection. - policy = ws.retryPolicyFactory.NewPolicy(ctx) - - // Store the connection so writing can take place. - ws.m.Lock() - ws.conn = conn - ws.m.Unlock() - // Read loop for { - var msg wrp.Message - typ, reader, err := ws.conn.Reader(ctx) - if err == nil { - if typ != nhws.MessageBinary { - err = ErrInvalidMsgType - } else { - decoder.Reset(reader) - err = decoder.Decode(&msg) - } - } - + msg, err := ws.readMsg(ctx, mode) + // If a reconnect was attempted but failed, ErrClosed will be found + // in this error list and a reconnect should be attempted again. + reconnect = errors.Is(err, ErrClosed) if err != nil { - ws.m.Lock() - ws.conn = nil - ws.m.Unlock() - - // The websocket gave us an unexpected message, or a message - // that could not be decoded. Close & reconnect. - _ = conn.Close(nhws.StatusUnsupportedData, limit(err.Error())) - - dEvent := event.Disconnect{ - At: ws.nowFunc(), - Err: err, - } - ws.disconnectListeners.Visit(func(l event.DisconnectListener) { - l.OnDisconnect(dEvent) - }) - break } ws.msgListeners.Visit(func(l event.MsgListener) { - l.OnMessage(msg) + l.OnMessage(*msg) }) - - // TODO - This section simply sends back the received wrp msg as a respond to the client's request. This will be replaced - var frameContents []byte - - // if the request was in a format other than Msgpack, or if the caller did not pass - // Contents, then do the encoding here. - encoder.ResetBytes(&frameContents) - err = encoder.Encode(msg) - encoder.ResetBytes(&emptyBuffer) - if err != nil { - ws.disconnectListeners.Visit(func(l event.DisconnectListener) { - l.OnDisconnect(event.Disconnect{ - At: ws.nowFunc(), - Err: fmt.Errorf("xmidt-agent failed to response to wrp message: %s", err), - }) - }) - - continue - } - - ws.conn.Write(ctx, nhws.MessageBinary, frameContents) } } @@ -285,15 +224,8 @@ func (ws *Websocket) run(ctx context.Context) { return } - next, _ = policy.Next() - - if dialErr != nil { - cEvent.Err = dialErr - cEvent.RetryingAt = ws.nowFunc().Add(next) - ws.connectListeners.Visit(func(l event.ConnectListener) { - l.OnConnect(cEvent) - }) - } + mode = ws.nextMode(mode) + next, _ := ws.policy.Next() select { case <-time.After(next): @@ -303,15 +235,121 @@ func (ws *Websocket) run(ctx context.Context) { } } -func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.Response, error) { +func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message, err error) { + defer func() { + if err != nil { + // The websocket either failed to read, gave us an unexpected message or a message + // that could not be decoded. Attempt to reconnect. + // If the reconnect fails, ErrClosed will be added to the error list allowing downstream to attempt a reconnect. + // Otherwise, ErrClosed will not be added to the error list. + err = errors.Join(err, ws.dial(ctx, mode)) + } + }() + + var ( + typ nhws.MessageType + reader io.Reader + ) + typ, reader, err = ws.conn.Reader(ctx) + + if err == nil { + if typ != nhws.MessageBinary { + err = ErrInvalidMsgType + } else { + ws.decoder.Reset(reader) + msg = &wrp.Message{} + err = ws.decoder.Decode(msg) + } + } + + if err != nil { + // The websocket gave us an unexpected message, or a message + // that could not be decoded. Close and send a disconnect event. + _ = ws.conn.Close(nhws.StatusUnsupportedData, limit(err.Error())) + ws.disconnectListeners.Visit(func(l event.DisconnectListener) { + l.OnDisconnect(event.Disconnect{ + At: ws.nowFunc(), + Err: err, + }) + }) + } + + return +} + +func (ws *Websocket) dial(ctx context.Context, mode ipMode) (err error) { + var ( + conn *nhws.Conn + resp *http.Response + ) + + ws.m.Lock() + defer ws.m.Unlock() + defer func() { + // Reconnect was successful, store the connection and send connect event. + if err == nil { + if resp.Body != nil { + resp.Body.Close() + } + + ws.conn = conn + // Reset the retry policy on a successful connection. + ws.policy = ws.retryPolicyFactory.NewPolicy(ctx) + conn.SetReadLimit(ws.maxMessageBytes) + ws.connectListeners.Visit(func(l event.ConnectListener) { + l.OnConnect(event.Connect{ + Started: ws.nowFunc(), + Mode: mode.ToEvent(), + At: ws.nowFunc(), + }) + }) + + return + } + + next, _ := ws.policy.Next() + // Send a connect event with the error that caused the failed connection + // (it'll never be an auth error). + ws.connectListeners.Visit(func(l event.ConnectListener) { + l.OnConnect(event.Connect{ + Started: ws.nowFunc(), + Mode: mode.ToEvent(), + At: ws.nowFunc(), + Err: err, + RetryingAt: ws.nowFunc().Add(next), + }) + }) + + // Failed to reconnect, add a ErrClosed to the error list allowing downstream to attempt a reconnect. + err = errors.Join(err, ErrClosed) + }() + fetchCtx, cancel := context.WithTimeout(ctx, ws.urlFetchingTimeout) defer cancel() - url, err := ws.urlFetcher(fetchCtx) + + var url string + url, err = ws.urlFetcher(fetchCtx) if err != nil { - return nil, nil, err + return } - conn, resp, err := nhws.Dial(ctx, url, + // If auth fails, then continue with an openfail (no themis token) xmidt connection. + // An auth error will never trigger an attempt to reconnect. + if err := ws.credDecorator(ws.additionalHeaders); err != nil { + next, _ := ws.policy.Next() + // Send a connect event with the auth error. + ws.connectListeners.Visit(func(l event.ConnectListener) { + l.OnConnect(event.Connect{ + Started: ws.nowFunc(), + Mode: mode.ToEvent(), + At: ws.nowFunc(), + Err: err, + RetryingAt: ws.nowFunc().Add(next), + }) + }) + } + + conn, resp, err = nhws.Dial(ctx, url, &nhws.DialOptions{ HTTPHeader: ws.additionalHeaders, HTTPClient: &http.Client{ @@ -320,12 +358,8 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R }, }, ) - if err != nil { - return nil, resp, err - } - conn.SetReadLimit(ws.maxMessageBytes) - return conn, resp, nil + return } type custRT struct { From d5f571053cab5790359914158245819ab1381d71 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Tue, 2 Apr 2024 18:49:19 -0400 Subject: [PATCH 2/5] chore: fix tests --- internal/websocket/e2e_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index e91af1c..209eec2 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -92,7 +92,6 @@ func TestEndToEnd(t *testing.T) { Jitter: 1.0 / 3.0, MaxInterval: 341*time.Second + 333*time.Millisecond, }), - ws.WithIPv6(), ws.WithIPv4(), ws.NowFunc(time.Now), ws.ConnectTimeout(30*time.Second), @@ -110,7 +109,7 @@ func TestEndToEnd(t *testing.T) { // Allow multiple calls to start. got.Start() - ctx, cancel := context.WithTimeout(context.Background(), 2000*time.Millisecond) + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() for { From 77872d04f2fa3727349ff329eb933f608ed0a80a Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Tue, 2 Apr 2024 21:18:42 -0400 Subject: [PATCH 3/5] chore: update ws.nextMode & tests --- internal/websocket/ws.go | 34 +++++++++++++++++----------------- internal/websocket/ws_test.go | 3 ++- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index 727a33c..482f0de 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -101,6 +101,7 @@ type Websocket struct { wg sync.WaitGroup shutdown context.CancelFunc + mode ipMode policy retry.Policy decoder wrp.Decoder encoder wrp.Encoder @@ -155,11 +156,12 @@ func (ws *Websocket) Start() { var ctx context.Context ctx, ws.shutdown = context.WithCancel(context.Background()) + ws.mode = ipv4 // Init retry policy, but it'll be reset on recurring successful connections. ws.policy = ws.retryPolicyFactory.NewPolicy(ctx) ws.decoder = wrp.NewDecoder(nil, wrp.Msgpack) - go ws.readPump(ctx) + go ws.read(ctx) } // Stop stops the websocket connection. @@ -191,22 +193,21 @@ func (ws *Websocket) Send(ctx context.Context, msg wrp.Message) error { return err } -func (ws *Websocket) readPump(ctx context.Context) { +func (ws *Websocket) read(ctx context.Context) { ws.wg.Add(1) defer ws.wg.Done() - mode := ws.nextMode(ipv4) reconnect := true for { var dialErr error if reconnect { - dialErr = ws.dial(ctx, mode) + dialErr = ws.dial(ctx) } if dialErr == nil { // Read loop for { - msg, err := ws.readMsg(ctx, mode) + msg, err := ws.readMsg(ctx) // If a reconnect was attempted but failed, ErrClosed will be found // in this error list and a reconnect should be attempted again. reconnect = errors.Is(err, ErrClosed) @@ -224,7 +225,6 @@ func (ws *Websocket) readPump(ctx context.Context) { return } - mode = ws.nextMode(mode) next, _ := ws.policy.Next() select { @@ -235,14 +235,14 @@ func (ws *Websocket) readPump(ctx context.Context) { } } -func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message, err error) { +func (ws *Websocket) readMsg(ctx context.Context) (msg *wrp.Message, err error) { defer func() { if err != nil { // The websocket either failed to read, gave us an unexpected message or a message // that could not be decoded. Attempt to reconnect. // If the reconnect fails, ErrClosed will be added to the error list allowing downstream to attempt a reconnect. // Otherwise, ErrClosed will not be added to the error list. - err = errors.Join(err, ws.dial(ctx, mode)) + err = errors.Join(err, ws.dial(ctx)) } }() @@ -277,8 +277,9 @@ func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message return } -func (ws *Websocket) dial(ctx context.Context, mode ipMode) (err error) { +func (ws *Websocket) dial(ctx context.Context) (err error) { var ( + mode = ws.nextMode() conn *nhws.Conn resp *http.Response ) @@ -372,6 +373,7 @@ func (rt *custRT) RoundTrip(r *http.Request) (*http.Response, error) { // getRT returns a custom RoundTripper for the WS connection. func (ws *Websocket) getRT(mode ipMode) *custRT { + dialer := &net.Dialer{ Timeout: ws.connectTimeout, KeepAlive: ws.keepAliveInterval, @@ -394,16 +396,14 @@ func (ws *Websocket) getRT(mode ipMode) *custRT { } } -func (ws *Websocket) nextMode(mode ipMode) ipMode { - if mode == ipv4 && ws.withIPv6 { - return ipv6 - } - - if mode == ipv6 && ws.withIPv4 { - return ipv4 +func (ws *Websocket) nextMode() ipMode { + if ws.mode == ipv4 && ws.withIPv6 { + ws.mode = ipv6 + } else if ws.mode == ipv6 && ws.withIPv4 { + ws.mode = ipv4 } - return mode + return ws.mode } func limit(s string) string { diff --git a/internal/websocket/ws_test.go b/internal/websocket/ws_test.go index b1d3024..5c24080 100644 --- a/internal/websocket/ws_test.go +++ b/internal/websocket/ws_test.go @@ -366,9 +366,10 @@ func TestNextMode(t *testing.T) { URL("http://example.com"), ) got, err := New(opts...) + got.mode = tc.mode require.NoError(err) require.NotNil(got) - assert.Equal(tc.expected, got.nextMode(tc.mode)) + assert.Equal(tc.expected, got.nextMode()) }) } } From 239e04cb796ca419cb75a5f70ba719131ec02057 Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Wed, 3 Apr 2024 06:30:08 -0400 Subject: [PATCH 4/5] chore: update based on pr --- internal/websocket/e2e_test.go | 40 +++++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 209eec2..22d36ac 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -15,6 +15,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" "github.com/xmidt-org/retry" "github.com/xmidt-org/wrp-go/v3" ws "github.com/xmidt-org/xmidt-agent/internal/websocket" @@ -22,11 +23,22 @@ import ( "nhooyr.io/websocket" ) -func TestEndToEnd(t *testing.T) { - var finished bool +type EndToEndTestSuite struct { + suite.Suite + finished bool + websocket *ws.Websocket +} - assert := assert.New(t) - require := require.New(t) +// Make sure that VariableThatShouldStartAtFive is set to five +// before each test +func (ts *EndToEndTestSuite) AfterTest(suiteName, testName string) { + ts.finished = true + ts.websocket.Stop() +} + +func (ts *EndToEndTestSuite) TestEndToEnd() { + assert := assert.New(ts.T()) + require := require.New(ts.T()) s := httptest.NewServer( http.HandlerFunc( @@ -47,7 +59,7 @@ func TestEndToEnd(t *testing.T) { mt, got, err := c.Read(ctx) // server will halt until the websocket closes resulting in a EOF - if finished { + if ts.finished { return } @@ -66,7 +78,7 @@ func TestEndToEnd(t *testing.T) { var msgCnt, connectCnt, disconnectCnt atomic.Int64 - got, err := ws.New( + websocket, err := ws.New( ws.URL(s.URL), ws.DeviceID("mac:112233445566"), ws.AddMessageListener( @@ -102,12 +114,13 @@ func TestEndToEnd(t *testing.T) { }), ) require.NoError(err) - require.NotNil(got) + require.NotNil(websocket) - got.Start() + ts.websocket = websocket + ts.websocket.Start() // Allow multiple calls to start. - got.Start() + ts.websocket.Start() ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() @@ -124,7 +137,7 @@ func TestEndToEnd(t *testing.T) { } } - got.Send(context.Background(), + ts.websocket.Send(context.Background(), wrp.Message{ Type: wrp.SimpleEventMessageType, Source: "client", @@ -142,9 +155,10 @@ func TestEndToEnd(t *testing.T) { } time.Sleep(10 * time.Millisecond) } - finished = true - got.Stop() - time.Sleep(10 * time.Millisecond) +} + +func TestEndToEnd(t *testing.T) { + suite.Run(t, new(EndToEndTestSuite)) } func TestEndToEndBadData(t *testing.T) { From ba86d4af8232ceb4d762c89cc773365a47eb49ab Mon Sep 17 00:00:00 2001 From: Owen Cabalceta Date: Wed, 3 Apr 2024 07:07:10 -0400 Subject: [PATCH 5/5] chore: improve test --- internal/websocket/e2e_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 22d36ac..a05f7cd 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -5,6 +5,7 @@ package websocket_test import ( "context" + "errors" "fmt" "net/http" "net/http/httptest" @@ -59,7 +60,9 @@ func (ts *EndToEndTestSuite) TestEndToEnd() { mt, got, err := c.Read(ctx) // server will halt until the websocket closes resulting in a EOF - if ts.finished { + var closeErr websocket.CloseError + if ts.finished && errors.As(err, &closeErr) { + assert.Equal(closeErr.Code, websocket.StatusNormalClosure) return }