Skip to content

Commit

Permalink
feat: improve ws handling
Browse files Browse the repository at this point in the history
  • Loading branch information
denopink committed Apr 2, 2024
1 parent 3f3687f commit 3dd1841
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 102 deletions.
2 changes: 2 additions & 0 deletions cmd/xmidt-agent/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ func provideCredentials(in credsIn) (*credentials.Credentials, error) {

opts := []credentials.Option{
credentials.URL(in.Creds.URL),
// enabling `Required` allows the xmidt-agent to send connect events for auth related errors
credentials.Required(),
credentials.HTTPClient(client),
credentials.MacAddress(in.ID.DeviceID),
credentials.SerialNumber(in.ID.SerialNumber),
Expand Down
14 changes: 12 additions & 2 deletions internal/websocket/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package websocket_test

import (
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync"
Expand All @@ -23,6 +25,8 @@ import (
)

func TestEndToEnd(t *testing.T) {
var finished bool

assert := assert.New(t)
require := require.New(t)

Expand All @@ -44,14 +48,19 @@ func TestEndToEnd(t *testing.T) {
require.NoError(err)

mt, got, err := c.Read(ctx)
// server will halt until the websocket closes resulting in a EOF
if finished && errors.Is(err, io.EOF) {
return
}

require.NoError(err)
require.Equal(websocket.MessageBinary, mt)
require.NotEmpty(got)

err = wrp.NewDecoderBytes(got, wrp.Msgpack).Decode(&msg)
require.NoError(err)
require.Equal(wrp.SimpleEventMessageType, msg.Type)
require.Equal("server", msg.Source)
require.Equal("client", msg.Source)

c.Close(websocket.StatusNormalClosure, "")
}))
Expand Down Expand Up @@ -138,6 +147,7 @@ func TestEndToEnd(t *testing.T) {
}
time.Sleep(10 * time.Millisecond)
got.Stop()
finished = true
}

func TestEndToEndBadData(t *testing.T) {
Expand Down Expand Up @@ -256,7 +266,7 @@ func TestEndToEndConnectionIssues(t *testing.T) {
require.NoError(err)
defer c.CloseNow()

ctx, cancel := context.WithTimeout(r.Context(), 2000000*time.Millisecond)
ctx, cancel := context.WithTimeout(r.Context(), 200*time.Millisecond)
defer cancel()

msg := wrp.Message{
Expand Down
234 changes: 134 additions & 100 deletions internal/websocket/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package websocket
import (
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"sync"
Expand Down Expand Up @@ -101,7 +101,10 @@ type Websocket struct {
wg sync.WaitGroup
shutdown context.CancelFunc

conn *nhws.Conn
policy retry.Policy
decoder wrp.Decoder
encoder wrp.Encoder
conn *nhws.Conn
}

// Option is a functional option type for WS.
Expand Down Expand Up @@ -152,13 +155,18 @@ func (ws *Websocket) Start() {

var ctx context.Context
ctx, ws.shutdown = context.WithCancel(context.Background())
// Init retry policy, but it'll be reset on recurring successful connections.
ws.policy = ws.retryPolicyFactory.NewPolicy(ctx)
ws.decoder = wrp.NewDecoder(nil, wrp.Msgpack)

go ws.run(ctx)
go ws.readPump(ctx)
}

// Stop stops the websocket connection.
func (ws *Websocket) Stop() {
ws.m.Lock()
// Avoid the overhead of the close handshake.
ws.conn.Close(nhws.StatusNormalClosure, "")
shutdown := ws.shutdown
ws.m.Unlock()

Expand All @@ -183,117 +191,41 @@ func (ws *Websocket) Send(ctx context.Context, msg wrp.Message) error {
return err
}

func (ws *Websocket) run(ctx context.Context) {
func (ws *Websocket) readPump(ctx context.Context) {
ws.wg.Add(1)
defer ws.wg.Done()

decoder := wrp.NewDecoder(nil, wrp.Msgpack)
encoder := wrp.NewEncoder(nil, wrp.Msgpack)
mode := ws.nextMode(ipv4)

policy := ws.retryPolicyFactory.NewPolicy(ctx)

reconnect := true
for {
var next time.Duration

mode = ws.nextMode(mode)
cEvent := event.Connect{
Started: ws.nowFunc(),
Mode: mode.ToEvent(),
var dialErr error
if reconnect {
dialErr = ws.dial(ctx, mode)
}

// If auth fails, then continue with openfail xmidt connection
ws.credDecorator(ws.additionalHeaders)

conn, _, dialErr := ws.dial(ctx, mode) //nolint:bodyclose
cEvent.At = ws.nowFunc()

if dialErr == nil {
ws.connectListeners.Visit(func(l event.ConnectListener) {
l.OnConnect(cEvent)
})

// Reset the retry policy on a successful connection.
policy = ws.retryPolicyFactory.NewPolicy(ctx)

// Store the connection so writing can take place.
ws.m.Lock()
ws.conn = conn
ws.m.Unlock()

// Read loop
for {
var msg wrp.Message
typ, reader, err := ws.conn.Reader(ctx)
if err == nil {
if typ != nhws.MessageBinary {
err = ErrInvalidMsgType
} else {
decoder.Reset(reader)
err = decoder.Decode(&msg)
}
}

msg, err := ws.readMsg(ctx, mode)
// If a reconnect was attempted but failed, ErrClosed will be found
// in this error list and a reconnect should be attempted again.
reconnect = errors.Is(err, ErrClosed)
if err != nil {
ws.m.Lock()
ws.conn = nil
ws.m.Unlock()

// The websocket gave us an unexpected message, or a message
// that could not be decoded. Close & reconnect.
_ = conn.Close(nhws.StatusUnsupportedData, limit(err.Error()))

dEvent := event.Disconnect{
At: ws.nowFunc(),
Err: err,
}
ws.disconnectListeners.Visit(func(l event.DisconnectListener) {
l.OnDisconnect(dEvent)
})

break
}

ws.msgListeners.Visit(func(l event.MsgListener) {
l.OnMessage(msg)
l.OnMessage(*msg)
})

// TODO - This section simply sends back the received wrp msg as a respond to the client's request. This will be replaced
var frameContents []byte

// if the request was in a format other than Msgpack, or if the caller did not pass
// Contents, then do the encoding here.
encoder.ResetBytes(&frameContents)
err = encoder.Encode(msg)
encoder.ResetBytes(&emptyBuffer)
if err != nil {
ws.disconnectListeners.Visit(func(l event.DisconnectListener) {
l.OnDisconnect(event.Disconnect{
At: ws.nowFunc(),
Err: fmt.Errorf("xmidt-agent failed to response to wrp message: %s", err),
})
})

continue
}

ws.conn.Write(ctx, nhws.MessageBinary, frameContents)
}
}

if ws.once {
return
}

next, _ = policy.Next()

if dialErr != nil {
cEvent.Err = dialErr
cEvent.RetryingAt = ws.nowFunc().Add(next)
ws.connectListeners.Visit(func(l event.ConnectListener) {
l.OnConnect(cEvent)
})
}
mode = ws.nextMode(mode)
next, _ := ws.policy.Next()

select {
case <-time.After(next):
Expand All @@ -303,15 +235,121 @@ func (ws *Websocket) run(ctx context.Context) {
}
}

func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.Response, error) {
func (ws *Websocket) readMsg(ctx context.Context, mode ipMode) (msg *wrp.Message, err error) {
defer func() {
if err != nil {
// The websocket either failed to read, gave us an unexpected message or a message
// that could not be decoded. Attempt to reconnect.
// If the reconnect fails, ErrClosed will be added to the error list allowing downstream to attempt a reconnect.
// Otherwise ErrClosed is not in the error list.
err = errors.Join(err, ws.dial(ctx, mode))
}
}()

var (
typ nhws.MessageType
reader io.Reader
)
typ, reader, err = ws.conn.Reader(ctx)

if err == nil {
if typ != nhws.MessageBinary {
err = ErrInvalidMsgType
} else {
ws.decoder.Reset(reader)
msg = &wrp.Message{}
err = ws.decoder.Decode(msg)
}
}

if err != nil {
// The websocket gave us an unexpected message, or a message
// that could not be decoded. Close and send a disconnect event.
_ = ws.conn.Close(nhws.StatusUnsupportedData, limit(err.Error()))
ws.disconnectListeners.Visit(func(l event.DisconnectListener) {
l.OnDisconnect(event.Disconnect{
At: ws.nowFunc(),
Err: err,
})
})
}

return
}

func (ws *Websocket) dial(ctx context.Context, mode ipMode) (err error) {
var (
conn *nhws.Conn
resp *http.Response
)

ws.m.Lock()
defer ws.m.Unlock()
defer func() {
// Reconnect was successful, store the connection and send connect event.
if err == nil {
if resp.Body != nil {
resp.Body.Close()
}

ws.conn = conn
// Reset the retry policy on a successful connection.
ws.policy = ws.retryPolicyFactory.NewPolicy(ctx)
conn.SetReadLimit(ws.maxMessageBytes)
ws.connectListeners.Visit(func(l event.ConnectListener) {
l.OnConnect(event.Connect{
Started: ws.nowFunc(),
Mode: mode.ToEvent(),
At: ws.nowFunc(),
})
})

return
}

next, _ := ws.policy.Next()
// Send a connect event with the error that caused the failed connection
// (it'll never be an auth error).
ws.connectListeners.Visit(func(l event.ConnectListener) {
l.OnConnect(event.Connect{
Started: ws.nowFunc(),
Mode: mode.ToEvent(),
At: ws.nowFunc(),
Err: err,
RetryingAt: ws.nowFunc().Add(next),
})
})

// Failed to reconnect, add a ErrClosed to the error list allowing downstream to attempt a reconnect.
err = errors.Join(err, ErrClosed)
}()

fetchCtx, cancel := context.WithTimeout(ctx, ws.urlFetchingTimeout)
defer cancel()
url, err := ws.urlFetcher(fetchCtx)

var url string
url, err = ws.urlFetcher(fetchCtx)
if err != nil {
return nil, nil, err
return
}

conn, resp, err := nhws.Dial(ctx, url,
// If auth fails, then continue with an openfail (no themis token) xmidt connection.
// An auth error will never trigger an attempt to reconnect.
if err := ws.credDecorator(ws.additionalHeaders); err != nil {
next, _ := ws.policy.Next()
// Send a connect event with the auth error.
ws.connectListeners.Visit(func(l event.ConnectListener) {
l.OnConnect(event.Connect{
Started: ws.nowFunc(),
Mode: mode.ToEvent(),
At: ws.nowFunc(),
Err: err,
RetryingAt: ws.nowFunc().Add(next),
})
})
}

conn, resp, err = nhws.Dial(ctx, url,
&nhws.DialOptions{
HTTPHeader: ws.additionalHeaders,
HTTPClient: &http.Client{
Expand All @@ -320,12 +358,8 @@ func (ws *Websocket) dial(ctx context.Context, mode ipMode) (*nhws.Conn, *http.R
},
},
)
if err != nil {
return nil, resp, err
}

conn.SetReadLimit(ws.maxMessageBytes)
return conn, resp, nil
return
}

type custRT struct {
Expand Down

0 comments on commit 3dd1841

Please sign in to comment.