Skip to content

Commit

Permalink
support writing the same Frame to multiple nodes (#66) (#72)
Browse files Browse the repository at this point in the history
  • Loading branch information
aler9 authored Aug 18, 2023
1 parent 140deff commit 49e386d
Show file tree
Hide file tree
Showing 9 changed files with 216 additions and 82 deletions.
3 changes: 2 additions & 1 deletion endpoint_broadcast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ func TestEndpointBroadcast(t *testing.T) {
SystemStatus: 2,
MavlinkVersion: 1,
}
node.WriteMessageAll(msg)
err := node.WriteMessageAll(msg)
require.NoError(t, err)

fr, err := rw.Read()
require.NoError(t, err)
Expand Down
6 changes: 4 additions & 2 deletions endpoint_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,15 @@ func TestEndpointClient(t *testing.T) {
}, evt)

for i := 0; i < 3; i++ {
node.WriteMessageAll(&MessageHeartbeat{
err := node.WriteMessageAll(&MessageHeartbeat{
Type: 1,
Autopilot: 2,
BaseMode: 3,
CustomMode: 6,
SystemStatus: 4,
MavlinkVersion: 5,
})
require.NoError(t, err)

evt = <-node.Events()
require.Equal(t, &EventFrame{
Expand Down Expand Up @@ -214,14 +215,15 @@ func TestEndpointClientIdleTimeout(t *testing.T) {
Channel: evt.(*EventChannelOpen).Channel,
}, evt)

node.WriteMessageAll(&MessageHeartbeat{
err = node.WriteMessageAll(&MessageHeartbeat{
Type: 1,
Autopilot: 2,
BaseMode: 3,
CustomMode: 6,
SystemStatus: 4,
MavlinkVersion: 5,
})
require.NoError(t, err)

select {
case <-closed:
Expand Down
3 changes: 2 additions & 1 deletion endpoint_custom_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ func TestEndpointCustom(t *testing.T) {
SystemStatus: 2,
MavlinkVersion: 1,
}
node.WriteMessageAll(msg)
err := node.WriteMessageAll(msg)
require.NoError(t, err)

fr, err := rw.Read()
require.NoError(t, err)
Expand Down
9 changes: 6 additions & 3 deletions endpoint_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,15 @@ func TestEndpointSerial(t *testing.T) {
Channel: evt.(*EventFrame).Channel,
}, evt)

node.WriteMessageAll(&MessageHeartbeat{
err := node.WriteMessageAll(&MessageHeartbeat{
Type: 6,
Autopilot: 5,
BaseMode: 4,
CustomMode: 3,
SystemStatus: 2,
MavlinkVersion: 1,
})
require.NoError(t, err)
}

<-done
Expand Down Expand Up @@ -251,14 +252,15 @@ func TestEndpointSerialReconnect(t *testing.T) {
Channel: evt.(*EventFrame).Channel,
}, evt)

node.WriteMessageAll(&MessageHeartbeat{
err = node.WriteMessageAll(&MessageHeartbeat{
Type: 6,
Autopilot: 5,
BaseMode: 4,
CustomMode: 3,
SystemStatus: 2,
MavlinkVersion: 1,
})
require.NoError(t, err)

evt = <-node.Events()
require.Equal(t, &EventChannelClose{
Expand All @@ -270,14 +272,15 @@ func TestEndpointSerialReconnect(t *testing.T) {
Channel: evt.(*EventChannelOpen).Channel,
}, evt)

node.WriteMessageAll(&MessageHeartbeat{
err = node.WriteMessageAll(&MessageHeartbeat{
Type: 7,
Autopilot: 5,
BaseMode: 4,
CustomMode: 3,
SystemStatus: 2,
MavlinkVersion: 1,
})
require.NoError(t, err)

<-done
}
3 changes: 2 additions & 1 deletion endpoint_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func TestEndpointServer(t *testing.T) {
SystemStatus: 2,
MavlinkVersion: 1,
}
node.WriteMessageAll(msg)
err = node.WriteMessageAll(msg)
require.NoError(t, err)

fr, err := rw.Read()
require.NoError(t, err)
Expand Down
145 changes: 86 additions & 59 deletions node.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,30 +266,17 @@ outer:
if _, ok := n.channels[req.ch]; !ok {
continue
}

var err error
req.what, err = n.encodeMessage(req.what)
if err == nil {
req.ch.write(req.what)
}
req.ch.write(req.what)

case what := <-n.chWriteAll:
var err error
what, err = n.encodeMessage(what)
if err == nil {
for ch := range n.channels {
ch.write(what)
}
for ch := range n.channels {
ch.write(what)
}

case req := <-n.chWriteExcept:
var err error
req.what, err = n.encodeMessage(req.what)
if err == nil {
for ch := range n.channels {
if ch != req.except {
ch.write(req.what)
}
for ch := range n.channels {
if ch != req.except {
ch.write(req.what)
}
}

Expand Down Expand Up @@ -322,7 +309,7 @@ outer:
// FixFrame recomputes the Frame checksum and signature.
// This can be called on Frames whose content has been edited.
func (n *Node) FixFrame(fr frame.Frame) error {
_, err := n.encodeMessage(fr)
err := n.encodeFrame(fr)
if err != nil {
return err
}
Expand Down Expand Up @@ -352,49 +339,47 @@ func (n *Node) FixFrame(fr frame.Frame) error {
return nil
}

// encode messages once before sending them to the channel's frame.ReadWriter.
func (n *Node) encodeMessage(what interface{}) (interface{}, error) {
switch twhat := what.(type) {
case message.Message:
if _, ok := twhat.(*message.MessageRaw); !ok {
if n.dialectRW == nil {
return nil, fmt.Errorf("dialect is nil")
}
func (n *Node) encodeFrame(fr frame.Frame) error {
if _, ok := fr.GetMessage().(*message.MessageRaw); !ok {
if n.dialectRW == nil {
return fmt.Errorf("dialect is nil")
}

mp := n.dialectRW.GetMessage(twhat.GetID())
if mp == nil {
return nil, fmt.Errorf("message is not in the dialect")
}
mp := n.dialectRW.GetMessage(fr.GetMessage().GetID())
if mp == nil {
return fmt.Errorf("message is not in the dialect")
}

msgRaw := mp.Write(twhat, n.conf.OutVersion == V2)
_, isV2 := fr.(*frame.V2Frame)
msgRaw := mp.Write(fr.GetMessage(), isV2)

return msgRaw, nil
switch fr := fr.(type) {
case *frame.V1Frame:
fr.Message = msgRaw
case *frame.V2Frame:
fr.Message = msgRaw
}
}

case frame.Frame:
if _, ok := twhat.GetMessage().(*message.MessageRaw); !ok {
if n.dialectRW == nil {
return nil, fmt.Errorf("dialect is nil")
}

mp := n.dialectRW.GetMessage(twhat.GetMessage().GetID())
if mp == nil {
return nil, fmt.Errorf("message is not in the dialect")
}
return nil
}

_, isV2 := twhat.(*frame.V2Frame)
msgRaw := mp.Write(twhat.GetMessage(), isV2)
func (n *Node) encodeMessage(msg message.Message) (message.Message, error) {
if _, ok := msg.(*message.MessageRaw); !ok {
if n.dialectRW == nil {
return nil, fmt.Errorf("dialect is nil")
}

switch ff := twhat.(type) {
case *frame.V1Frame:
ff.Message = msgRaw
case *frame.V2Frame:
ff.Message = msgRaw
}
mp := n.dialectRW.GetMessage(msg.GetID())
if mp == nil {
return nil, fmt.Errorf("message is not in the dialect")
}

msgRaw := mp.Write(msg, n.conf.OutVersion == V2)
return msgRaw, nil
}

return what, nil
return msg, nil
}

// Events returns a channel from which receiving events. Possible events are:
Expand All @@ -411,57 +396,99 @@ func (n *Node) Events() chan Event {
}

// WriteMessageTo writes a message to given channel.
func (n *Node) WriteMessageTo(channel *Channel, m message.Message) {
func (n *Node) WriteMessageTo(channel *Channel, m message.Message) error {
m, err := n.encodeMessage(m)
if err != nil {
return err
}

select {
case n.chWriteTo <- writeToReq{channel, m}:
case <-n.terminate:
}

return nil
}

// WriteMessageAll writes a message to all channels.
func (n *Node) WriteMessageAll(m message.Message) {
func (n *Node) WriteMessageAll(m message.Message) error {
m, err := n.encodeMessage(m)
if err != nil {
return err
}

select {
case n.chWriteAll <- m:
case <-n.terminate:
}

return nil
}

// WriteMessageExcept writes a message to all channels except specified channel.
func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) {
func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) error {
m, err := n.encodeMessage(m)
if err != nil {
return err
}

select {
case n.chWriteExcept <- writeExceptReq{exceptChannel, m}:
case <-n.terminate:
}

return nil
}

// WriteFrameTo writes a frame to given channel.
// This function is intended only for routing pre-existing frames to other nodes,
// since all frame fields must be filled manually.
func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) {
func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) error {
err := n.encodeFrame(fr)
if err != nil {
return err
}

select {
case n.chWriteTo <- writeToReq{channel, fr}:
case <-n.terminate:
}

return nil
}

// WriteFrameAll writes a frame to all channels.
// This function is intended only for routing pre-existing frames to other nodes,
// since all frame fields must be filled manually.
func (n *Node) WriteFrameAll(fr frame.Frame) {
func (n *Node) WriteFrameAll(fr frame.Frame) error {
err := n.encodeFrame(fr)
if err != nil {
return err
}

select {
case n.chWriteAll <- fr:
case <-n.terminate:
}

return nil
}

// WriteFrameExcept writes a frame to all channels except specified channel.
// This function is intended only for routing pre-existing frames to other nodes,
// since all frame fields must be filled manually.
func (n *Node) WriteFrameExcept(exceptChannel *Channel, fr frame.Frame) {
func (n *Node) WriteFrameExcept(exceptChannel *Channel, fr frame.Frame) error {
err := n.encodeFrame(fr)
if err != nil {
return err
}

select {
case n.chWriteExcept <- writeExceptReq{exceptChannel, fr}:
case <-n.terminate:
}

return nil
}

func (n *Node) pushEvent(evt Event) {
Expand Down
2 changes: 1 addition & 1 deletion node_heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func (h *nodeHeartbeat) run() {
m.Elem().FieldByName("CustomMode").SetUint(0)
m.Elem().FieldByName("SystemStatus").SetUint(4) // MAV_STATE_ACTIVE
m.Elem().FieldByName("MavlinkVersion").SetUint(uint64(h.n.conf.Dialect.Version))
h.n.WriteMessageAll(m.Interface().(message.Message))
h.n.WriteMessageAll(m.Interface().(message.Message)) //nolint:errcheck

case <-h.terminate:
return
Expand Down
2 changes: 1 addition & 1 deletion node_streamrequest.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func (sr *nodeStreamRequest) onEventFrame(evt *EventFrame) {
m.Elem().FieldByName("ReqStreamId").SetUint(uint64(stream))
m.Elem().FieldByName("ReqMessageRate").SetUint(uint64(sr.n.conf.StreamRequestFrequency))
m.Elem().FieldByName("StartStop").SetUint(uint64(1))
sr.n.WriteMessageTo(evt.Channel, m.Interface().(message.Message))
sr.n.WriteMessageTo(evt.Channel, m.Interface().(message.Message)) //nolint:errcheck
}

sr.n.pushEvent(&EventStreamRequested{
Expand Down
Loading

0 comments on commit 49e386d

Please sign in to comment.