Skip to content

Commit

Permalink
emit ChannelClose and ChannelOpen when a client connects and disconne…
Browse files Browse the repository at this point in the history
…cts (#71)
  • Loading branch information
aler9 authored Aug 18, 2023
1 parent 51190a4 commit 140deff
Show file tree
Hide file tree
Showing 17 changed files with 486 additions and 504 deletions.
8 changes: 3 additions & 5 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ import (
)

const (
// this is low in order to avoid accumulating messages
// when a channel is reconnecting
writeBufferSize = 8
writeBufferSize = 64
)

func randomByte() (byte, error) {
Expand Down Expand Up @@ -93,12 +91,12 @@ func (ch *Channel) close() {

func (ch *Channel) start() {
ch.running = true
ch.n.channelsWg.Add(1)
ch.n.wg.Add(1)
go ch.run()
}

func (ch *Channel) run() {
defer ch.n.channelsWg.Done()
defer ch.n.wg.Done()

readerDone := make(chan struct{})
go ch.runReader(readerDone)
Expand Down
20 changes: 10 additions & 10 deletions channel_accepter.go → channel_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@ import (
"fmt"
)

type channelAccepter struct {
type channelProvider struct {
n *Node
eca endpointChannelAccepter
eca endpointChannelProvider
}

func newChannelAccepter(n *Node, eca endpointChannelAccepter) (*channelAccepter, error) {
return &channelAccepter{
func newChannelProvider(n *Node, eca endpointChannelProvider) (*channelProvider, error) {
return &channelProvider{
n: n,
eca: eca,
}, nil
}

func (ca *channelAccepter) close() {
func (ca *channelProvider) close() {
ca.eca.close()
}

func (ca *channelAccepter) start() {
ca.n.channelAcceptersWg.Add(1)
func (ca *channelProvider) start() {
ca.n.wg.Add(1)
go ca.run()
}

func (ca *channelAccepter) run() {
defer ca.n.channelAcceptersWg.Done()
func (ca *channelProvider) run() {
defer ca.n.wg.Done()

for {
label, rwc, err := ca.eca.accept()
label, rwc, err := ca.eca.provide()
if err != nil {
if err != errTerminated {
panic("errTerminated is the only error allowed here")
Expand Down
8 changes: 4 additions & 4 deletions endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type Endpoint interface {

// a endpoint must also implement one of the following:
// - endpointChannelSingle
// - endpointChannelAccepter
// - endpointChannelProvider

// endpointChannelSingle is an endpoint that provides a single channel.
// Read() must not return any error unless Close() is called.
Expand All @@ -28,9 +28,9 @@ type endpointChannelSingle interface {
io.ReadWriteCloser
}

// endpointChannelAccepter is an endpoint that provides multiple channels.
type endpointChannelAccepter interface {
// endpointChannelProvider is an endpoint that provides multiple channels.
type endpointChannelProvider interface {
Endpoint
close()
accept() (string, io.ReadWriteCloser, error)
provide() (string, io.ReadWriteCloser, error)
}
2 changes: 1 addition & 1 deletion endpoint_broadcast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func TestEndpointBroadcast(t *testing.T) {
})
require.NoError(t, err)

for i := 0; i < 3; i++ {
for i := 0; i < 3; i++ { //nolint:dupl
msg := &MessageHeartbeat{
Type: 1,
Autopilot: 2,
Expand Down
26 changes: 19 additions & 7 deletions endpoint_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"io"
"net"

"github.com/bluenviron/gomavlib/v2/pkg/autoreconnector"
"github.com/bluenviron/gomavlib/v2/pkg/reconnector"
"github.com/bluenviron/gomavlib/v2/pkg/timednetconn"
)

Expand Down Expand Up @@ -56,8 +56,8 @@ func (conf EndpointUDPClient) init(node *Node) (Endpoint, error) {
}

type endpointClient struct {
conf endpointClientConf
io.ReadWriteCloser
conf endpointClientConf
reconnector *reconnector.Reconnector
}

func initEndpointClient(node *Node, conf endpointClientConf) (Endpoint, error) {
Expand All @@ -68,18 +68,17 @@ func initEndpointClient(node *Node, conf endpointClientConf) (Endpoint, error) {

t := &endpointClient{
conf: conf,
ReadWriteCloser: autoreconnector.New(
reconnector: reconnector.New(
func(ctx context.Context) (io.ReadWriteCloser, error) {
// solve address and connect
// in UDP, the only possible error is a DNS failure
// in TCP, the handshake must be completed
network := func() string {
if conf.isUDP() {
return "udp4"
}
return "tcp4"
}()

// in UDP, the only possible error is a DNS failure
// in TCP, the handshake must be completed
timedContext, timedContextClose := context.WithTimeout(ctx, node.conf.ReadTimeout)
nconn, err := (&net.Dialer{}).DialContext(timedContext, network, conf.getAddress())
timedContextClose()
Expand All @@ -106,6 +105,19 @@ func (t *endpointClient) Conf() EndpointConf {
return t.conf
}

func (t *endpointClient) close() {
t.reconnector.Close()
}

func (t *endpointClient) provide() (string, io.ReadWriteCloser, error) {
conn, ok := t.reconnector.Reconnect()
if !ok {
return "", nil, errTerminated
}

return t.label(), conn, nil
}

func (t *endpointClient) label() string {
return fmt.Sprintf("%s:%s", func() string {
if t.conf.isUDP() {
Expand Down
17 changes: 1 addition & 16 deletions endpoint_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
"github.com/bluenviron/gomavlib/v2/pkg/frame"
)

var _ endpointChannelSingle = (*endpointClient)(nil)
var _ endpointChannelProvider = (*endpointClient)(nil)

func TestEndpointClient(t *testing.T) {
for _, ca := range []string{"tcp", "udp"} {
Expand All @@ -32,15 +32,11 @@ func TestEndpointClient(t *testing.T) {
}
defer ln.Close()

connected := make(chan struct{})

go func() {
conn, err := ln.Accept()
require.NoError(t, err)
defer conn.Close()

close(connected)

dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

Expand Down Expand Up @@ -104,12 +100,6 @@ func TestEndpointClient(t *testing.T) {
Channel: evt.(*EventChannelOpen).Channel,
}, evt)

if ca == "tcp" {
<-connected
} else {
time.Sleep(100 * time.Millisecond) // wait UDP channel to reach connected status
}

for i := 0; i < 3; i++ {
node.WriteMessageAll(&MessageHeartbeat{
Type: 1,
Expand Down Expand Up @@ -152,16 +142,13 @@ func TestEndpointClientIdleTimeout(t *testing.T) {
require.NoError(t, err)
defer ln.Close()

connected := make(chan struct{})
closed := make(chan struct{})
reconnected := make(chan struct{})

go func() {
conn, err := ln.Accept()
require.NoError(t, err)

close(connected)

dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

Expand Down Expand Up @@ -227,8 +214,6 @@ func TestEndpointClientIdleTimeout(t *testing.T) {
Channel: evt.(*EventChannelOpen).Channel,
}, evt)

<-connected

node.WriteMessageAll(&MessageHeartbeat{
Type: 1,
Autopilot: 2,
Expand Down
41 changes: 17 additions & 24 deletions endpoint_custom_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gomavlib

import (
"bytes"
"errors"
"io"
"testing"
Expand All @@ -14,39 +13,39 @@ import (

var _ endpointChannelSingle = (*endpointCustom)(nil)

type dummyEndpoint struct {
type dummyReadWriter struct {
chOut chan []byte
chIn chan []byte
chReadErr chan struct{}
}

func newDummyEndpoint() *dummyEndpoint {
return &dummyEndpoint{
func newDummyReadWriterPair() (*dummyReadWriter, *dummyReadWriter) {
one := &dummyReadWriter{
chOut: make(chan []byte),
chIn: make(chan []byte),
chReadErr: make(chan struct{}),
}
}

func (e *dummyEndpoint) simulateReadError() {
close(e.chReadErr)
}
two := &dummyReadWriter{
chOut: one.chIn,
chIn: one.chOut,
chReadErr: make(chan struct{}),
}

func (e *dummyEndpoint) push(buf []byte) {
e.chOut <- buf
return one, two
}

func (e *dummyEndpoint) pull() []byte {
return <-e.chIn
func (e *dummyReadWriter) simulateReadError() {
close(e.chReadErr)
}

func (e *dummyEndpoint) Close() error {
func (e *dummyReadWriter) Close() error {
close(e.chOut)
close(e.chIn)
return nil
}

func (e *dummyEndpoint) Read(p []byte) (int, error) {
func (e *dummyReadWriter) Read(p []byte) (int, error) {
select {
case buf, ok := <-e.chOut:
if !ok {
Expand All @@ -58,19 +57,19 @@ func (e *dummyEndpoint) Read(p []byte) (int, error) {
}
}

func (e *dummyEndpoint) Write(p []byte) (int, error) {
func (e *dummyReadWriter) Write(p []byte) (int, error) {
e.chIn <- p
return len(p), nil
}

func TestEndpointCustom(t *testing.T) {
de := newDummyEndpoint()
remote, local := newDummyReadWriterPair()

node, err := NewNode(NodeConf{
Dialect: testDialect,
OutVersion: V2,
OutSystemID: 10,
Endpoints: []EndpointConf{EndpointCustom{de}},
Endpoints: []EndpointConf{EndpointCustom{remote}},
HeartbeatDisable: true,
})
require.NoError(t, err)
Expand All @@ -84,10 +83,8 @@ func TestEndpointCustom(t *testing.T) {
dialectRW, err := dialect.NewReadWriter(testDialect)
require.NoError(t, err)

var buf bytes.Buffer

rw, err := frame.NewReadWriter(frame.ReadWriterConf{
ReadWriter: &buf,
ReadWriter: local,
DialectRW: dialectRW,
OutVersion: frame.V2,
OutSystemID: 11,
Expand All @@ -105,8 +102,6 @@ func TestEndpointCustom(t *testing.T) {
}
err = rw.WriteMessage(msg)
require.NoError(t, err)
de.push(buf.Bytes())
buf.Reset()

evt = <-node.Events()
require.Equal(t, &EventFrame{
Expand All @@ -130,8 +125,6 @@ func TestEndpointCustom(t *testing.T) {
}
node.WriteMessageAll(msg)

buf2 := de.pull()
buf.Write(buf2)
fr, err := rw.Read()
require.NoError(t, err)
require.Equal(t, &frame.V2Frame{
Expand Down
21 changes: 15 additions & 6 deletions endpoint_serial.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (

"github.com/tarm/serial"

"github.com/bluenviron/gomavlib/v2/pkg/autoreconnector"
"github.com/bluenviron/gomavlib/v2/pkg/reconnector"
)

var serialOpenFunc = func(device string, baud int) (io.ReadWriteCloser, error) {
Expand All @@ -26,8 +26,8 @@ type EndpointSerial struct {
}

type endpointSerial struct {
conf EndpointConf
io.ReadWriteCloser
conf EndpointConf
reconnector *reconnector.Reconnector
}

func (conf EndpointSerial) init(_ *Node) (Endpoint, error) {
Expand All @@ -40,7 +40,7 @@ func (conf EndpointSerial) init(_ *Node) (Endpoint, error) {

t := &endpointSerial{
conf: conf,
ReadWriteCloser: autoreconnector.New(
reconnector: reconnector.New(
func(ctx context.Context) (io.ReadWriteCloser, error) {
return serialOpenFunc(conf.Device, conf.Baud)
},
Expand All @@ -56,6 +56,15 @@ func (t *endpointSerial) Conf() EndpointConf {
return t.conf
}

func (t *endpointSerial) label() string {
return "serial"
func (t *endpointSerial) close() {
t.reconnector.Close()
}

func (t *endpointSerial) provide() (string, io.ReadWriteCloser, error) {
conn, ok := t.reconnector.Reconnect()
if !ok {
return "", nil, errTerminated
}

return "serial", conn, nil
}
Loading

0 comments on commit 140deff

Please sign in to comment.