diff --git a/endpoint_client_test.go b/endpoint_client_test.go index bc92c4ab6..ecbcba706 100644 --- a/endpoint_client_test.go +++ b/endpoint_client_test.go @@ -30,9 +30,58 @@ func TestEndpointClient(t *testing.T) { ln, err = udp.Listen("udp", addr) require.NoError(t, err) } - defer ln.Close() + connected := make(chan struct{}) + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + defer conn.Close() + + close(connected) + + dialectRW, err := dialect.NewReadWriter(testDialect) + require.NoError(t, err) + + rw, err := frame.NewReadWriter(frame.ReadWriterConf{ + ReadWriter: conn, + DialectRW: dialectRW, + OutVersion: frame.V2, + OutSystemID: 11, + }) + require.NoError(t, err) + + for i := 0; i < 3; i++ { + fr, err := rw.Read() + require.NoError(t, err) + require.Equal(t, &frame.V2Frame{ + SequenceID: byte(i), + SystemID: 10, + ComponentID: 1, + Message: &MessageHeartbeat{ + Type: 1, + Autopilot: 2, + BaseMode: 3, + CustomMode: 6, + SystemStatus: 4, + MavlinkVersion: 5, + }, + Checksum: fr.GetChecksum(), + }, fr) + + err = rw.WriteMessage(&MessageHeartbeat{ + Type: 6, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, + }) + require.NoError(t, err) + } + }() + var e EndpointConf if ca == "tcp" { e = EndpointTCPClient{"127.0.0.1:5601"} @@ -55,56 +104,21 @@ func TestEndpointClient(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - var rw *frame.ReadWriter + if ca == "tcp" { + <-connected + } else { + time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + } for i := 0; i < 3; i++ { - msg := &MessageHeartbeat{ + node.WriteMessageAll(&MessageHeartbeat{ Type: 1, Autopilot: 2, BaseMode: 3, CustomMode: 6, SystemStatus: 4, MavlinkVersion: 5, - } - node.WriteMessageAll(msg) - - if i == 0 { - conn, err := ln.Accept() - require.NoError(t, err) - defer conn.Close() - - dialectRW, err := dialect.NewReadWriter(testDialect) - require.NoError(t, err) - - rw, err = frame.NewReadWriter(frame.ReadWriterConf{ - ReadWriter: conn, - DialectRW: dialectRW, - OutVersion: frame.V2, - OutSystemID: 11, - }) - require.NoError(t, err) - } - - fr, err := rw.Read() - require.NoError(t, err) - require.Equal(t, &frame.V2Frame{ - SequenceID: byte(i), - SystemID: 10, - ComponentID: 1, - Message: msg, - Checksum: fr.GetChecksum(), - }, fr) - - msg = &MessageHeartbeat{ - Type: 6, - Autopilot: 5, - BaseMode: 4, - CustomMode: 3, - SystemStatus: 2, - MavlinkVersion: 1, - } - err = rw.WriteMessage(msg) - require.NoError(t, err) + }) evt = <-node.Events() require.Equal(t, &EventFrame{ @@ -112,8 +126,15 @@ func TestEndpointClient(t *testing.T) { SequenceID: byte(i), SystemID: 11, ComponentID: 1, - Message: msg, - Checksum: evt.(*EventFrame).Frame.GetChecksum(), + Message: &MessageHeartbeat{ + Type: 6, + Autopilot: 5, + BaseMode: 4, + CustomMode: 3, + SystemStatus: 2, + MavlinkVersion: 1, + }, + Checksum: evt.(*EventFrame).Frame.GetChecksum(), }, Channel: evt.(*EventFrame).Channel, }, evt) @@ -129,9 +150,60 @@ func TestEndpointClientIdleTimeout(t *testing.T) { var err error ln, err = net.Listen("tcp", "127.0.0.1:5603") require.NoError(t, err) - defer ln.Close() + connected := make(chan struct{}) + closed := make(chan struct{}) + reconnected := make(chan struct{}) + + go func() { + conn, err := ln.Accept() + require.NoError(t, err) + + close(connected) + + dialectRW, err := dialect.NewReadWriter(testDialect) + require.NoError(t, err) + + rw, err := frame.NewReadWriter(frame.ReadWriterConf{ + ReadWriter: conn, + DialectRW: dialectRW, + OutVersion: frame.V2, + OutSystemID: 11, + }) + require.NoError(t, err) + + fr, err := rw.Read() + require.NoError(t, err) + require.Equal(t, &frame.V2Frame{ + SequenceID: 0, + SystemID: 10, + ComponentID: 1, + Message: &MessageHeartbeat{ + Type: 1, + Autopilot: 2, + BaseMode: 3, + CustomMode: 6, + SystemStatus: 4, + MavlinkVersion: 5, + }, + Checksum: fr.GetChecksum(), + }, fr) + + _, err = rw.Read() + require.Equal(t, io.EOF, err) + conn.Close() + + close(closed) + + // the client reconnects to the server due to autoReconnector + conn, err = ln.Accept() + require.NoError(t, err) + conn.Close() + + close(reconnected) + }() + var e EndpointConf if ca == "tcp" { e = EndpointTCPClient{"127.0.0.1:5603"} @@ -155,48 +227,16 @@ func TestEndpointClientIdleTimeout(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - msg := &MessageHeartbeat{ + <-connected + + node.WriteMessageAll(&MessageHeartbeat{ Type: 1, Autopilot: 2, BaseMode: 3, CustomMode: 6, SystemStatus: 4, MavlinkVersion: 5, - } - node.WriteMessageAll(msg) - - conn, err := ln.Accept() - require.NoError(t, err) - - dialectRW, err := dialect.NewReadWriter(testDialect) - require.NoError(t, err) - - rw, err := frame.NewReadWriter(frame.ReadWriterConf{ - ReadWriter: conn, - DialectRW: dialectRW, - OutVersion: frame.V2, - OutSystemID: 11, }) - require.NoError(t, err) - - fr, err := rw.Read() - require.NoError(t, err) - require.Equal(t, &frame.V2Frame{ - SequenceID: 0, - SystemID: 10, - ComponentID: 1, - Message: msg, - Checksum: fr.GetChecksum(), - }, fr) - - closed := make(chan struct{}) - - go func() { - _, err = rw.Read() - require.Equal(t, io.EOF, err) - conn.Close() - close(closed) - }() select { case <-closed: @@ -204,10 +244,7 @@ func TestEndpointClientIdleTimeout(t *testing.T) { t.Errorf("should not happen") } - // the client reconnects to the server due to autoReconnector - conn, err = ln.Accept() - require.NoError(t, err) - conn.Close() + <-reconnected }) } } diff --git a/node_test.go b/node_test.go index 40c97eda3..db7ab1abb 100644 --- a/node_test.go +++ b/node_test.go @@ -4,6 +4,7 @@ import ( "bytes" "sync" "testing" + "time" "github.com/stretchr/testify/require" @@ -96,6 +97,8 @@ func TestNodeCloseInLoop(t *testing.T) { require.NoError(t, err) defer node2.Close() + time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + node2.WriteMessageAll(testMessage) for evt := range node1.Events() { @@ -365,6 +368,8 @@ func TestNodeWriteMessageInLoop(t *testing.T) { require.NoError(t, err) defer node2.Close() + time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + node2.WriteMessageAll(testMessage) for evt := range node1.Events() { @@ -409,6 +414,8 @@ func TestNodeSignature(t *testing.T) { require.NoError(t, err) defer node2.Close() + time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + node2.WriteMessageAll(testMessage) <-node1.Events() @@ -519,6 +526,8 @@ func TestNodeFixFrame(t *testing.T) { require.NoError(t, err) defer node2.Close() + time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status + fra := &frame.V2Frame{ SequenceID: 13, SystemID: 15, diff --git a/pkg/autoreconnector/auto_reconnector.go b/pkg/autoreconnector/auto_reconnector.go index 172cd5c40..47d394eab 100644 --- a/pkg/autoreconnector/auto_reconnector.go +++ b/pkg/autoreconnector/auto_reconnector.go @@ -12,35 +12,50 @@ import ( var ( reconnectPeriod = 2 * time.Second errTerminated = errors.New("terminated") + errReconnecting = errors.New("reconnecting") +) + +type state int + +const ( + stateInitial state = iota + stateReconnecting + stateConnected + stateTerminated ) type autoReconnector struct { connect func(context.Context) (io.ReadWriteCloser, error) - ctx context.Context - ctxCancel func() - conn io.ReadWriteCloser - connMutex sync.Mutex + mutex sync.Mutex + state state + conn io.ReadWriteCloser + connectCtx context.Context + connectCtxCancel func() } // New returns a io.ReadWriterCloser that implements auto-reconnection. func New( connect func(context.Context) (io.ReadWriteCloser, error), ) io.ReadWriteCloser { - ctx, ctxCancel := context.WithCancel(context.Background()) - - return &autoReconnector{ - connect: connect, - ctx: ctx, - ctxCancel: ctxCancel, + a := &autoReconnector{ + connect: connect, } + + a.resetConnection() + + return a } func (a *autoReconnector) Close() error { - a.ctxCancel() + a.mutex.Lock() + defer a.mutex.Unlock() - a.connMutex.Lock() - defer a.connMutex.Unlock() + a.state = stateTerminated + + if a.connectCtxCancel != nil { + a.connectCtxCancel() + } if a.conn != nil { a.conn.Close() @@ -50,81 +65,107 @@ func (a *autoReconnector) Close() error { return nil } -func (a *autoReconnector) getConnection(reset bool) (io.ReadWriteCloser, bool) { - a.connMutex.Lock() - defer a.connMutex.Unlock() +func (a *autoReconnector) getConnection() (io.ReadWriteCloser, context.Context, error) { + a.mutex.Lock() + defer a.mutex.Unlock() - if a.conn != nil { - if !reset { - return a.conn, true - } + switch a.state { + case stateTerminated: + return nil, nil, errTerminated + + case stateReconnecting: + return nil, a.connectCtx, errReconnecting + default: + return a.conn, nil, nil + } +} + +func (a *autoReconnector) resetConnection() { + a.mutex.Lock() + defer a.mutex.Unlock() + + switch a.state { + case stateTerminated, stateReconnecting: + return + } + + a.state = stateReconnecting + + if a.conn != nil { a.conn.Close() a.conn = nil } - select { - case <-a.ctx.Done(): - return nil, false - default: - } + a.connectCtx, a.connectCtxCancel = context.WithCancel(context.Background()) - for { - var err error - a.conn, err = a.connect(a.ctx) - if err == nil { - select { - case <-a.ctx.Done(): - a.conn.Close() - a.conn = nil - return nil, false - default: + go func() { + for { + newConn, err := a.connect(a.connectCtx) + if err == nil { + a.setConn(newConn) + return } - return a.conn, true + select { + case <-time.After(reconnectPeriod): + case <-a.connectCtx.Done(): + return + } } + }() +} - select { - case <-time.After(reconnectPeriod): - case <-a.ctx.Done(): - return nil, false - } +func (a *autoReconnector) setConn(newConn io.ReadWriteCloser) { + a.mutex.Lock() + defer a.mutex.Unlock() + + if a.state != stateReconnecting { + newConn.Close() + return } + + a.connectCtxCancel() + a.connectCtxCancel = nil + + a.conn = newConn + a.state = stateConnected } func (a *autoReconnector) Read(p []byte) (int, error) { - reset := false - for { - curConn, ok := a.getConnection(reset) - if !ok { - return 0, errTerminated + curConn, connectCtx, err := a.getConnection() + if err == errReconnecting { + <-connectCtx.Done() + continue + } + if err != nil { + return 0, err } n, err := curConn.Read(p) - if err == nil || (err == io.EOF && n > 0) { - return n, err + + if n == 0 { + a.resetConnection() + continue } - reset = true + return n, err } } func (a *autoReconnector) Write(p []byte) (int, error) { - reset := false - - for { - curConn, ok := a.getConnection(reset) - if !ok { - return 0, errTerminated - } + curConn, _, err := a.getConnection() + if err != nil { + return 0, err + } - n, err := curConn.Write(p) + n, err := curConn.Write(p) - if err == nil { - return n, nil - } - - reset = true + if n == 0 { + a.resetConnection() + return 0, errReconnecting } + + return n, err }