Skip to content

Commit

Permalink
discard outgoing messages when an endpoint is reconnecting (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Aug 18, 2023
1 parent a025891 commit 51190a4
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 147 deletions.
207 changes: 122 additions & 85 deletions endpoint_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,58 @@ func TestEndpointClient(t *testing.T) {
ln, err = udp.Listen("udp", addr)
require.NoError(t, err)
}

defer ln.Close()

connected := make(chan struct{})

go func() {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()

close(connected)

dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

rw, err := frame.NewReadWriter(frame.ReadWriterConf{
ReadWriter: conn,
DialectRW: dialectRW,
OutVersion: frame.V2,
OutSystemID: 11,
})
require.NoError(t, err)

for i := 0; i < 3; i++ {
fr, err := rw.Read()
require.NoError(t, err)
require.Equal(t, &frame.V2Frame{
SequenceID: byte(i),
SystemID: 10,
ComponentID: 1,
Message: &MessageHeartbeat{
Type: 1,
Autopilot: 2,
BaseMode: 3,
CustomMode: 6,
SystemStatus: 4,
MavlinkVersion: 5,
},
Checksum: fr.GetChecksum(),
}, fr)

err = rw.WriteMessage(&MessageHeartbeat{
Type: 6,
Autopilot: 5,
BaseMode: 4,
CustomMode: 3,
SystemStatus: 2,
MavlinkVersion: 1,
})
require.NoError(t, err)
}
}()

var e EndpointConf
if ca == "tcp" {
e = EndpointTCPClient{"127.0.0.1:5601"}
Expand All @@ -55,65 +104,37 @@ func TestEndpointClient(t *testing.T) {
Channel: evt.(*EventChannelOpen).Channel,
}, evt)

var rw *frame.ReadWriter
if ca == "tcp" {
<-connected
} else {
time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status
}

for i := 0; i < 3; i++ {
msg := &MessageHeartbeat{
node.WriteMessageAll(&MessageHeartbeat{
Type: 1,
Autopilot: 2,
BaseMode: 3,
CustomMode: 6,
SystemStatus: 4,
MavlinkVersion: 5,
}
node.WriteMessageAll(msg)

if i == 0 {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()

dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

rw, err = frame.NewReadWriter(frame.ReadWriterConf{
ReadWriter: conn,
DialectRW: dialectRW,
OutVersion: frame.V2,
OutSystemID: 11,
})
require.NoError(t, err)
}

fr, err := rw.Read()
require.NoError(t, err)
require.Equal(t, &frame.V2Frame{
SequenceID: byte(i),
SystemID: 10,
ComponentID: 1,
Message: msg,
Checksum: fr.GetChecksum(),
}, fr)

msg = &MessageHeartbeat{
Type: 6,
Autopilot: 5,
BaseMode: 4,
CustomMode: 3,
SystemStatus: 2,
MavlinkVersion: 1,
}
err = rw.WriteMessage(msg)
require.NoError(t, err)
})

evt = <-node.Events()
require.Equal(t, &EventFrame{
Frame: &frame.V2Frame{
SequenceID: byte(i),
SystemID: 11,
ComponentID: 1,
Message: msg,
Checksum: evt.(*EventFrame).Frame.GetChecksum(),
Message: &MessageHeartbeat{
Type: 6,
Autopilot: 5,
BaseMode: 4,
CustomMode: 3,
SystemStatus: 2,
MavlinkVersion: 1,
},
Checksum: evt.(*EventFrame).Frame.GetChecksum(),
},
Channel: evt.(*EventFrame).Channel,
}, evt)
Expand All @@ -129,9 +150,60 @@ func TestEndpointClientIdleTimeout(t *testing.T) {
var err error
ln, err = net.Listen("tcp", "127.0.0.1:5603")
require.NoError(t, err)

defer ln.Close()

connected := make(chan struct{})
closed := make(chan struct{})
reconnected := make(chan struct{})

go func() {
conn, err := ln.Accept()
require.NoError(t, err)

close(connected)

dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

rw, err := frame.NewReadWriter(frame.ReadWriterConf{
ReadWriter: conn,
DialectRW: dialectRW,
OutVersion: frame.V2,
OutSystemID: 11,
})
require.NoError(t, err)

fr, err := rw.Read()
require.NoError(t, err)
require.Equal(t, &frame.V2Frame{
SequenceID: 0,
SystemID: 10,
ComponentID: 1,
Message: &MessageHeartbeat{
Type: 1,
Autopilot: 2,
BaseMode: 3,
CustomMode: 6,
SystemStatus: 4,
MavlinkVersion: 5,
},
Checksum: fr.GetChecksum(),
}, fr)

_, err = rw.Read()
require.Equal(t, io.EOF, err)
conn.Close()

close(closed)

// the client reconnects to the server due to autoReconnector
conn, err = ln.Accept()
require.NoError(t, err)
conn.Close()

close(reconnected)
}()

var e EndpointConf
if ca == "tcp" {
e = EndpointTCPClient{"127.0.0.1:5603"}
Expand All @@ -155,59 +227,24 @@ func TestEndpointClientIdleTimeout(t *testing.T) {
Channel: evt.(*EventChannelOpen).Channel,
}, evt)

msg := &MessageHeartbeat{
<-connected

node.WriteMessageAll(&MessageHeartbeat{
Type: 1,
Autopilot: 2,
BaseMode: 3,
CustomMode: 6,
SystemStatus: 4,
MavlinkVersion: 5,
}
node.WriteMessageAll(msg)

conn, err := ln.Accept()
require.NoError(t, err)

dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

rw, err := frame.NewReadWriter(frame.ReadWriterConf{
ReadWriter: conn,
DialectRW: dialectRW,
OutVersion: frame.V2,
OutSystemID: 11,
})
require.NoError(t, err)

fr, err := rw.Read()
require.NoError(t, err)
require.Equal(t, &frame.V2Frame{
SequenceID: 0,
SystemID: 10,
ComponentID: 1,
Message: msg,
Checksum: fr.GetChecksum(),
}, fr)

closed := make(chan struct{})

go func() {
_, err = rw.Read()
require.Equal(t, io.EOF, err)
conn.Close()
close(closed)
}()

select {
case <-closed:
case <-time.After(1 * time.Second):
t.Errorf("should not happen")
}

// the client reconnects to the server due to autoReconnector
conn, err = ln.Accept()
require.NoError(t, err)
conn.Close()
<-reconnected
})
}
}
9 changes: 9 additions & 0 deletions node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"sync"
"testing"
"time"

"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -96,6 +97,8 @@ func TestNodeCloseInLoop(t *testing.T) {
require.NoError(t, err)
defer node2.Close()

time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status

node2.WriteMessageAll(testMessage)

for evt := range node1.Events() {
Expand Down Expand Up @@ -365,6 +368,8 @@ func TestNodeWriteMessageInLoop(t *testing.T) {
require.NoError(t, err)
defer node2.Close()

time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status

node2.WriteMessageAll(testMessage)

for evt := range node1.Events() {
Expand Down Expand Up @@ -409,6 +414,8 @@ func TestNodeSignature(t *testing.T) {
require.NoError(t, err)
defer node2.Close()

time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status

node2.WriteMessageAll(testMessage)

<-node1.Events()
Expand Down Expand Up @@ -519,6 +526,8 @@ func TestNodeFixFrame(t *testing.T) {
require.NoError(t, err)
defer node2.Close()

time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status

fra := &frame.V2Frame{
SequenceID: 13,
SystemID: 15,
Expand Down
Loading

0 comments on commit 51190a4

Please sign in to comment.