diff --git a/cmd/xmidt-agent/credentials.go b/cmd/xmidt-agent/credentials.go index eab9e63..2fca52c 100644 --- a/cmd/xmidt-agent/credentials.go +++ b/cmd/xmidt-agent/credentials.go @@ -43,6 +43,8 @@ func provideCredentials(in credsIn) (*credentials.Credentials, error) { opts := []credentials.Option{ credentials.URL(in.Creds.URL), + // enabling `Required` allows the xmidt-agent to send connect events for auth-related errors + credentials.Required(), credentials.HTTPClient(client), credentials.MacAddress(in.ID.DeviceID), credentials.SerialNumber(in.ID.SerialNumber), diff --git a/internal/websocket/e2e_test.go b/internal/websocket/e2e_test.go index 2467cc7..14f9028 100644 --- a/internal/websocket/e2e_test.go +++ b/internal/websocket/e2e_test.go @@ -5,7 +5,9 @@ package websocket_test import ( "context" + "errors" "fmt" + "io" "net/http" "net/http/httptest" "sync" @@ -23,6 +25,8 @@ import ( ) func TestEndToEnd(t *testing.T) { + var finished bool + assert := assert.New(t) require := require.New(t) @@ -44,6 +48,11 @@ func TestEndToEnd(t *testing.T) { require.NoError(err) mt, got, err := c.Read(ctx) + // server will halt until the websocket closes resulting in an EOF + if finished && errors.Is(err, io.EOF) { + return + } + require.NoError(err) require.Equal(websocket.MessageBinary, mt) require.NotEmpty(got) @@ -51,7 +60,7 @@ func TestEndToEnd(t *testing.T) { err = wrp.NewDecoderBytes(got, wrp.Msgpack).Decode(&msg) require.NoError(err) require.Equal(wrp.SimpleEventMessageType, msg.Type) - require.Equal("server", msg.Source) + require.Equal("client", msg.Source) c.Close(websocket.StatusNormalClosure, "") })) @@ -138,6 +147,7 @@ func TestEndToEnd(t *testing.T) { } time.Sleep(10 * time.Millisecond) got.Stop() + finished = true } func TestEndToEndBadData(t *testing.T) { @@ -256,7 +266,7 @@ func TestEndToEndConnectionIssues(t *testing.T) { require.NoError(err) defer c.CloseNow() - ctx, cancel := context.WithTimeout(r.Context(), 2000000*time.Millisecond) + ctx, cancel := 
context.WithTimeout(r.Context(), 200*time.Millisecond) defer cancel() msg := wrp.Message{ diff --git a/internal/websocket/ws.go b/internal/websocket/ws.go index f2418e9..d855a05 100644 --- a/internal/websocket/ws.go +++ b/internal/websocket/ws.go @@ -6,7 +6,7 @@ package websocket import ( "context" "errors" - "fmt" + "io" "net" "net/http" "sync" @@ -101,7 +101,10 @@ type Websocket struct { wg sync.WaitGroup shutdown context.CancelFunc - conn *nhws.Conn + policy retry.Policy + decoder wrp.Decoder + encoder wrp.Encoder + conn *nhws.Conn } // Option is a functional option type for WS. @@ -152,13 +155,18 @@ func (ws *Websocket) Start() { var ctx context.Context ctx, ws.shutdown = context.WithCancel(context.Background()) + // Init retry policy, but it'll be reset on recurring successful connections. + ws.policy = ws.retryPolicyFactory.NewPolicy(ctx) + ws.decoder = wrp.NewDecoder(nil, wrp.Msgpack) - go ws.run(ctx) + go ws.readPump(ctx) } // Stop stops the websocket connection. func (ws *Websocket) Stop() { ws.m.Lock() + // Avoid the overhead of the close handshake. 
+ ws.conn.CloseNow() shutdown := ws.shutdown ws.m.Unlock() @@ -183,101 +191,32 @@ func (ws *Websocket) Send(ctx context.Context, msg wrp.Message) error { return err } -func (ws *Websocket) run(ctx context.Context) { +func (ws *Websocket) readPump(ctx context.Context) { ws.wg.Add(1) defer ws.wg.Done() - decoder := wrp.NewDecoder(nil, wrp.Msgpack) - encoder := wrp.NewEncoder(nil, wrp.Msgpack) mode := ws.nextMode(ipv4) - - policy := ws.retryPolicyFactory.NewPolicy(ctx) - + reconnect := true for { - var next time.Duration - - mode = ws.nextMode(mode) - cEvent := event.Connect{ - Started: ws.nowFunc(), - Mode: mode.ToEvent(), + var dialErr error + if reconnect { + dialErr = ws.dial(ctx, mode) } - // If auth fails, then continue with openfail xmidt connection - ws.credDecorator(ws.additionalHeaders) - - conn, _, dialErr := ws.dial(ctx, mode) //nolint:bodyclose - cEvent.At = ws.nowFunc() - if dialErr == nil { - ws.connectListeners.Visit(func(l event.ConnectListener) { - l.OnConnect(cEvent) - }) - - // Reset the retry policy on a successful connection. - policy = ws.retryPolicyFactory.NewPolicy(ctx) - - // Store the connection so writing can take place. - ws.m.Lock() - ws.conn = conn - ws.m.Unlock() - // Read loop for { - var msg wrp.Message - typ, reader, err := ws.conn.Reader(ctx) - if err == nil { - if typ != nhws.MessageBinary { - err = ErrInvalidMsgType - } else { - decoder.Reset(reader) - err = decoder.Decode(&msg) - } - } - + msg, err := ws.readMsg(ctx, mode) + // If a reconnect was attempted but failed, ErrClosed will be found + // in this error list and a reconnect should be attempted again. + reconnect = errors.Is(err, ErrClosed) if err != nil { - ws.m.Lock() - ws.conn = nil - ws.m.Unlock() - - // The websocket gave us an unexpected message, or a message - // that could not be decoded. Close & reconnect. 
- _ = conn.Close(nhws.StatusUnsupportedData, limit(err.Error())) - - dEvent := event.Disconnect{ - At: ws.nowFunc(), - Err: err, - } - ws.disconnectListeners.Visit(func(l event.DisconnectListener) { - l.OnDisconnect(dEvent) - }) - break } ws.msgListeners.Visit(func(l event.MsgListener) { - l.OnMessage(msg) + l.OnMessage(*msg) }) - - // TODO - This section simply sends back the received wrp msg as a respond to the client's request. This will be replaced - var frameContents []byte - - // if the request was in a format other than Msgpack, or if the caller did not pass - // Contents, then do the encoding here. - encoder.ResetBytes(&frameContents) - err = encoder.Encode(msg) - encoder.ResetBytes(&emptyBuffer) - if err != nil { - ws.disconnectListeners.Visit(func(l event.DisconnectListener) { - l.OnDisconnect(event.Disconnect{ - At: ws.nowFunc(), - Err: fmt.Errorf("xmidt-agent failed to response to wrp message: %s", err), - }) - }) - - continue - } - - ws.conn.Write(ctx, nhws.MessageBinary, frameContents) } } @@ -285,15 +224,8 @@ func (ws *Websocket) run(ctx context.Context) { return } - next, _ = policy.Next() - - if dialErr != nil { - cEvent.Err = dialErr - cEvent.RetryingAt = ws.nowFunc().Add(next) - ws.connectListeners.Visit(func(l event.ConnectListener) { - l.OnConnect(cEvent) - }) - } + mode = ws.nextMode(mode) + next, _ := ws.policy.Next() select { case <-time.After(next): @@ -303,15 +235,121 @@ func (ws *Websocket) run(ctx context.Context) { } } -func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.Response, error) { +func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message, err error) { + defer func() { + if err != nil { + // The websocket either failed to read, gave us an unexpected message or a message + // that could not be decoded. Attempt to reconnect. + // If the reconnect fails, ErrClosed will be added to the error list allowing downstream to attempt a reconnect. 
+ // Otherwise ErrClosed is not in the error list. + err = errors.Join(err, ws.dial(ctx, mode)) + } + }() + + var ( + typ nhws.MessageType + reader io.Reader + ) + typ, reader, err = ws.conn.Reader(ctx) + + if err == nil { + if typ != nhws.MessageBinary { + err = ErrInvalidMsgType + } else { + ws.decoder.Reset(reader) + msg = &wrp.Message{} + err = ws.decoder.Decode(msg) + } + } + + if err != nil { + // The websocket gave us an unexpected message, or a message + // that could not be decoded. Close and send a disconnect event. + _ = ws.conn.Close(nhws.StatusUnsupportedData, limit(err.Error())) + ws.disconnectListeners.Visit(func(l event.DisconnectListener) { + l.OnDisconnect(event.Disconnect{ + At: ws.nowFunc(), + Err: err, + }) + }) + } + + return +} + +func (ws *Websocket) dial(ctx context.Context, mode ipMode) (err error) { + var ( + conn *nhws.Conn + resp *http.Response + ) + + ws.m.Lock() + defer ws.m.Unlock() + defer func() { + // Reconnect was successful, store the connection and send connect event. + if err == nil { + if resp.Body != nil { + resp.Body.Close() + } + + ws.conn = conn + // Reset the retry policy on a successful connection. + ws.policy = ws.retryPolicyFactory.NewPolicy(ctx) + conn.SetReadLimit(ws.maxMessageBytes) + ws.connectListeners.Visit(func(l event.ConnectListener) { + l.OnConnect(event.Connect{ + Started: ws.nowFunc(), + Mode: mode.ToEvent(), + At: ws.nowFunc(), + }) + }) + + return + } + + next, _ := ws.policy.Next() + // Send a connect event with the error that caused the failed connection + // (it'll never be an auth error). + ws.connectListeners.Visit(func(l event.ConnectListener) { + l.OnConnect(event.Connect{ + Started: ws.nowFunc(), + Mode: mode.ToEvent(), + At: ws.nowFunc(), + Err: err, + RetryingAt: ws.nowFunc().Add(next), + }) + }) + + // Failed to reconnect, add ErrClosed to the error list allowing downstream to attempt a reconnect. 
+ err = errors.Join(err, ErrClosed) + }() + fetchCtx, cancel := context.WithTimeout(ctx, ws.urlFetchingTimeout) defer cancel() - url, err := ws.urlFetcher(fetchCtx) + + var url string + url, err = ws.urlFetcher(fetchCtx) if err != nil { - return nil, nil, err + return } - conn, resp, err := nhws.Dial(ctx, url, + // If auth fails, then continue with an openfail (no themis token) xmidt connection. + // An auth error will never trigger an attempt to reconnect. + if err := ws.credDecorator(ws.additionalHeaders); err != nil { + next, _ := ws.policy.Next() + // Send a connect event with the auth error. + ws.connectListeners.Visit(func(l event.ConnectListener) { + l.OnConnect(event.Connect{ + Started: ws.nowFunc(), + Mode: mode.ToEvent(), + At: ws.nowFunc(), + Err: err, + RetryingAt: ws.nowFunc().Add(next), + }) + }) + } + + conn, resp, err = nhws.Dial(ctx, url, &nhws.DialOptions{ HTTPHeader: ws.additionalHeaders, HTTPClient: &http.Client{ @@ -320,12 +358,8 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R }, }, ) - if err != nil { - return nil, resp, err - } - conn.SetReadLimit(ws.maxMessageBytes) - return conn, resp, nil + return } type custRT struct {