diff --git a/endpoint_broadcast_test.go b/endpoint_broadcast_test.go index 4ce21e859..e08f0d88c 100644 --- a/endpoint_broadcast_test.go +++ b/endpoint_broadcast_test.go @@ -99,7 +99,8 @@ func TestEndpointBroadcast(t *testing.T) { SystemStatus: 2, MavlinkVersion: 1, } - node.WriteMessageAll(msg) + err := node.WriteMessageAll(msg) + require.NoError(t, err) fr, err := rw.Read() require.NoError(t, err) diff --git a/endpoint_client_test.go b/endpoint_client_test.go index be005e39a..fa4e106f9 100644 --- a/endpoint_client_test.go +++ b/endpoint_client_test.go @@ -101,7 +101,7 @@ func TestEndpointClient(t *testing.T) { }, evt) for i := 0; i < 3; i++ { - node.WriteMessageAll(&MessageHeartbeat{ + err := node.WriteMessageAll(&MessageHeartbeat{ Type: 1, Autopilot: 2, BaseMode: 3, @@ -109,6 +109,7 @@ func TestEndpointClient(t *testing.T) { SystemStatus: 4, MavlinkVersion: 5, }) + require.NoError(t, err) evt = <-node.Events() require.Equal(t, &EventFrame{ @@ -214,7 +215,7 @@ func TestEndpointClientIdleTimeout(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - node.WriteMessageAll(&MessageHeartbeat{ + err = node.WriteMessageAll(&MessageHeartbeat{ Type: 1, Autopilot: 2, BaseMode: 3, @@ -222,6 +223,7 @@ func TestEndpointClientIdleTimeout(t *testing.T) { SystemStatus: 4, MavlinkVersion: 5, }) + require.NoError(t, err) select { case <-closed: diff --git a/endpoint_custom_test.go b/endpoint_custom_test.go index 114f7fd0b..bd529c258 100644 --- a/endpoint_custom_test.go +++ b/endpoint_custom_test.go @@ -123,7 +123,8 @@ func TestEndpointCustom(t *testing.T) { SystemStatus: 2, MavlinkVersion: 1, } - node.WriteMessageAll(msg) + err := node.WriteMessageAll(msg) + require.NoError(t, err) fr, err := rw.Read() require.NoError(t, err) diff --git a/endpoint_serial_test.go b/endpoint_serial_test.go index 2925697a3..aa3fc4716 100644 --- a/endpoint_serial_test.go +++ b/endpoint_serial_test.go @@ -110,7 +110,7 @@ func TestEndpointSerial(t *testing.T) { Channel: evt.(*EventFrame).Channel, }, evt) - node.WriteMessageAll(&MessageHeartbeat{ + err := node.WriteMessageAll(&MessageHeartbeat{ Type: 6, Autopilot: 5, BaseMode: 4, @@ -118,6 +118,7 @@ func TestEndpointSerial(t *testing.T) { SystemStatus: 2, MavlinkVersion: 1, }) + require.NoError(t, err) } <-done @@ -251,7 +252,7 @@ func TestEndpointSerialReconnect(t *testing.T) { Channel: evt.(*EventFrame).Channel, }, evt) - node.WriteMessageAll(&MessageHeartbeat{ + err = node.WriteMessageAll(&MessageHeartbeat{ Type: 6, Autopilot: 5, BaseMode: 4, @@ -259,6 +260,7 @@ func TestEndpointSerialReconnect(t *testing.T) { SystemStatus: 2, MavlinkVersion: 1, }) + require.NoError(t, err) evt = <-node.Events() require.Equal(t, &EventChannelClose{ @@ -270,7 +272,7 @@ func TestEndpointSerialReconnect(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - node.WriteMessageAll(&MessageHeartbeat{ + err = node.WriteMessageAll(&MessageHeartbeat{ Type: 7, Autopilot: 5, BaseMode: 4, @@ -278,6 +280,7 @@ func TestEndpointSerialReconnect(t *testing.T) { SystemStatus: 2, MavlinkVersion: 1, }) + require.NoError(t, err) <-done } diff --git a/endpoint_server_test.go b/endpoint_server_test.go index 3096b7343..5f905c376 100644 --- a/endpoint_server_test.go +++ b/endpoint_server_test.go @@ -87,7 +87,8 @@ func TestEndpointServer(t *testing.T) { SystemStatus: 2, MavlinkVersion: 1, } - node.WriteMessageAll(msg) + err = node.WriteMessageAll(msg) + require.NoError(t, err) fr, err := rw.Read() require.NoError(t, err) diff --git a/node.go b/node.go index 579456e35..5e6b684fc 100644 --- a/node.go +++ b/node.go @@ -266,30 +266,17 @@ outer: if _, ok := n.channels[req.ch]; !ok { continue } - - var err error - req.what, err = n.encodeMessage(req.what) - if err == nil { - req.ch.write(req.what) - } + req.ch.write(req.what) case what := <-n.chWriteAll: - var err error - what, err = n.encodeMessage(what) - if err == nil { - for ch := range n.channels { - ch.write(what) - } + for ch := range n.channels { + ch.write(what) } case req := <-n.chWriteExcept: - var err error - req.what, err = n.encodeMessage(req.what) - if err == nil { - for ch := range n.channels { - if ch != req.except { - ch.write(req.what) - } + for ch := range n.channels { + if ch != req.except { + ch.write(req.what) } } @@ -322,7 +309,7 @@ outer: // FixFrame recomputes the Frame checksum and signature. // This can be called on Frames whose content has been edited. func (n *Node) FixFrame(fr frame.Frame) error { - _, err := n.encodeMessage(fr) + err := n.encodeFrame(fr) if err != nil { return err } @@ -352,49 +339,47 @@ func (n *Node) FixFrame(fr frame.Frame) error { return nil } -// encode messages once before sending them to the channel's frame.ReadWriter. -func (n *Node) encodeMessage(what interface{}) (interface{}, error) { - switch twhat := what.(type) { - case message.Message: - if _, ok := twhat.(*message.MessageRaw); !ok { - if n.dialectRW == nil { - return nil, fmt.Errorf("dialect is nil") - } +func (n *Node) encodeFrame(fr frame.Frame) error { + if _, ok := fr.GetMessage().(*message.MessageRaw); !ok { + if n.dialectRW == nil { + return fmt.Errorf("dialect is nil") + } - mp := n.dialectRW.GetMessage(twhat.GetID()) - if mp == nil { - return nil, fmt.Errorf("message is not in the dialect") - } + mp := n.dialectRW.GetMessage(fr.GetMessage().GetID()) + if mp == nil { + return fmt.Errorf("message is not in the dialect") + } - msgRaw := mp.Write(twhat, n.conf.OutVersion == V2) + _, isV2 := fr.(*frame.V2Frame) + msgRaw := mp.Write(fr.GetMessage(), isV2) - return msgRaw, nil + switch fr := fr.(type) { + case *frame.V1Frame: + fr.Message = msgRaw + case *frame.V2Frame: + fr.Message = msgRaw } + } - case frame.Frame: - if _, ok := twhat.GetMessage().(*message.MessageRaw); !ok { - if n.dialectRW == nil { - return nil, fmt.Errorf("dialect is nil") - } - - mp := n.dialectRW.GetMessage(twhat.GetMessage().GetID()) - if mp == nil { - return nil, fmt.Errorf("message is not in the dialect") - } + return nil +} - _, isV2 := twhat.(*frame.V2Frame) - msgRaw := mp.Write(twhat.GetMessage(), isV2) +func (n *Node) encodeMessage(msg message.Message) (message.Message, error) { + if _, ok := msg.(*message.MessageRaw); !ok { + if n.dialectRW == nil { + return nil, fmt.Errorf("dialect is nil") + } - switch ff := twhat.(type) { - case *frame.V1Frame: - ff.Message = msgRaw - case *frame.V2Frame: - ff.Message = msgRaw - } + mp := n.dialectRW.GetMessage(msg.GetID()) + if mp == nil { + return nil, fmt.Errorf("message is not in the dialect") } + + msgRaw := mp.Write(msg, n.conf.OutVersion == V2) + return msgRaw, nil } - return what, nil + return msg, nil } // Events returns a channel from which receiving events. Possible events are: @@ -411,57 +396,99 @@ func (n *Node) Events() chan Event { } // WriteMessageTo writes a message to given channel. -func (n *Node) WriteMessageTo(channel *Channel, m message.Message) { +func (n *Node) WriteMessageTo(channel *Channel, m message.Message) error { + m, err := n.encodeMessage(m) + if err != nil { + return err + } + select { case n.chWriteTo <- writeToReq{channel, m}: case <-n.terminate: } + + return nil } // WriteMessageAll writes a message to all channels. -func (n *Node) WriteMessageAll(m message.Message) { +func (n *Node) WriteMessageAll(m message.Message) error { + m, err := n.encodeMessage(m) + if err != nil { + return err + } + select { case n.chWriteAll <- m: case <-n.terminate: } + + return nil } // WriteMessageExcept writes a message to all channels except specified channel. -func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) { +func (n *Node) WriteMessageExcept(exceptChannel *Channel, m message.Message) error { + m, err := n.encodeMessage(m) + if err != nil { + return err + } + select { case n.chWriteExcept <- writeExceptReq{exceptChannel, m}: case <-n.terminate: } + + return nil } // WriteFrameTo writes a frame to given channel. // This function is intended only for routing pre-existing frames to other nodes, // since all frame fields must be filled manually. -func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) { +func (n *Node) WriteFrameTo(channel *Channel, fr frame.Frame) error { + err := n.encodeFrame(fr) + if err != nil { + return err + } + select { case n.chWriteTo <- writeToReq{channel, fr}: case <-n.terminate: } + + return nil } // WriteFrameAll writes a frame to all channels. // This function is intended only for routing pre-existing frames to other nodes, // since all frame fields must be filled manually. -func (n *Node) WriteFrameAll(fr frame.Frame) { +func (n *Node) WriteFrameAll(fr frame.Frame) error { + err := n.encodeFrame(fr) + if err != nil { + return err + } + select { case n.chWriteAll <- fr: case <-n.terminate: } + + return nil } // WriteFrameExcept writes a frame to all channels except specified channel. // This function is intended only for routing pre-existing frames to other nodes, // since all frame fields must be filled manually. -func (n *Node) WriteFrameExcept(exceptChannel *Channel, fr frame.Frame) { +func (n *Node) WriteFrameExcept(exceptChannel *Channel, fr frame.Frame) error { + err := n.encodeFrame(fr) + if err != nil { + return err + } + select { case n.chWriteExcept <- writeExceptReq{exceptChannel, fr}: case <-n.terminate: } + + return nil } func (n *Node) pushEvent(evt Event) { diff --git a/node_heartbeat.go b/node_heartbeat.go index ba489b181..f19a56fb7 100644 --- a/node_heartbeat.go +++ b/node_heartbeat.go @@ -77,7 +77,7 @@ func (h *nodeHeartbeat) run() { m.Elem().FieldByName("CustomMode").SetUint(0) m.Elem().FieldByName("SystemStatus").SetUint(4) // MAV_STATE_ACTIVE m.Elem().FieldByName("MavlinkVersion").SetUint(uint64(h.n.conf.Dialect.Version)) - h.n.WriteMessageAll(m.Interface().(message.Message)) + h.n.WriteMessageAll(m.Interface().(message.Message)) //nolint:errcheck case <-h.terminate: return diff --git a/node_streamrequest.go b/node_streamrequest.go index d5a49ea85..56bbe2b8c 100644 --- a/node_streamrequest.go +++ b/node_streamrequest.go @@ -171,7 +171,7 @@ func (sr *nodeStreamRequest) onEventFrame(evt *EventFrame) { m.Elem().FieldByName("ReqStreamId").SetUint(uint64(stream)) m.Elem().FieldByName("ReqMessageRate").SetUint(uint64(sr.n.conf.StreamRequestFrequency)) m.Elem().FieldByName("StartStop").SetUint(uint64(1)) - sr.n.WriteMessageTo(evt.Channel, m.Interface().(message.Message)) + sr.n.WriteMessageTo(evt.Channel, m.Interface().(message.Message)) //nolint:errcheck } sr.n.pushEvent(&EventStreamRequested{ diff --git a/node_test.go b/node_test.go index 3ea6ab9af..d34c0e1f4 100644 --- a/node_test.go +++ b/node_test.go @@ -101,7 +101,8 @@ func TestNodeCloseInLoop(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - node2.WriteMessageAll(testMessage) + err = node2.WriteMessageAll(testMessage) + require.NoError(t, err) for evt := range node1.Events() { if _, ok := evt.(*EventChannelOpen); ok { @@ -171,15 +172,17 @@ func TestNodeWriteAll(t *testing.T) { } if ca == "message" { - server.WriteMessageAll(testMessage) + err := server.WriteMessageAll(testMessage) + require.NoError(t, err) } else { - server.WriteFrameAll(&frame.V2Frame{ + err := server.WriteFrameAll(&frame.V2Frame{ SequenceID: 0, SystemID: 11, ComponentID: 1, Message: testMessage, Checksum: 55967, }) + require.NoError(t, err) } wg.Wait() }) @@ -251,15 +254,17 @@ func TestNodeWriteExcept(t *testing.T) { } if ca == "message" { - server.WriteMessageExcept(except, testMessage) + err := server.WriteMessageExcept(except, testMessage) + require.NoError(t, err) } else { - server.WriteFrameExcept(except, &frame.V2Frame{ + err := server.WriteFrameExcept(except, &frame.V2Frame{ SequenceID: 0, SystemID: 11, ComponentID: 1, Message: testMessage, Checksum: 55967, }) + require.NoError(t, err) } wg.Wait() }) @@ -330,15 +335,17 @@ func TestNodeWriteTo(t *testing.T) { } if ca == "message" { - server.WriteMessageTo(except, testMessage) + err := server.WriteMessageTo(except, testMessage) + require.NoError(t, err) } else { - server.WriteFrameTo(except, &frame.V2Frame{ + err := server.WriteFrameTo(except, &frame.V2Frame{ SequenceID: 0, SystemID: 11, ComponentID: 1, Message: testMessage, Checksum: 55967, }) + require.NoError(t, err) } <-recv }) @@ -375,12 +382,14 @@ func TestNodeWriteMessageInLoop(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - node2.WriteMessageAll(testMessage) + err = node2.WriteMessageAll(testMessage) + require.NoError(t, err) for evt := range node1.Events() { if _, ok := evt.(*EventChannelOpen); ok { for i := 0; i < 10; i++ { - node1.WriteMessageAll(testMessage) + err := node1.WriteMessageAll(testMessage) + require.NoError(t, err) } break } @@ -424,7 +433,8 @@ func TestNodeSignature(t *testing.T) { Channel: evt.(*EventChannelOpen).Channel, }, evt) - node2.WriteMessageAll(testMessage) + err = node2.WriteMessageAll(testMessage) + require.NoError(t, err) <-node1.Events() evt = <-node1.Events() @@ -484,14 +494,17 @@ func TestNodeRoute(t *testing.T) { require.NoError(t, err) defer node3.Close() - node1.WriteMessageAll(testMessage) + err = node1.WriteMessageAll(testMessage) + require.NoError(t, err) <-node2.Events() <-node2.Events() evt := <-node2.Events() fr, ok := evt.(*EventFrame) require.Equal(t, true, ok) - node2.WriteFrameExcept(fr.Channel, fr.Frame) + + err = node2.WriteFrameExcept(fr.Channel, fr.Frame) + require.NoError(t, err) <-node3.Events() evt = <-node3.Events() @@ -552,7 +565,9 @@ func TestNodeFixFrame(t *testing.T) { err = node2.FixFrame(fra) require.NoError(t, err) - node2.WriteFrameAll(fra) + + err = node2.WriteFrameAll(fra) + require.NoError(t, err) <-node1.Events() evt = <-node1.Events() @@ -573,3 +588,87 @@ func TestNodeFixFrame(t *testing.T) { Channel: fr.Channel, }, evt) } + +func TestNodeWriteSameToMultiple(t *testing.T) { + server, err := NewNode(NodeConf{ + Dialect: testDialect, + OutVersion: V2, + OutSystemID: 11, + Endpoints: []EndpointConf{ + EndpointTCPServer{"127.0.0.1:5600"}, + }, + HeartbeatDisable: true, + }) + require.NoError(t, err) + defer server.Close() + + client1, err := NewNode(NodeConf{ + Dialect: testDialect, + OutVersion: V2, + OutSystemID: 11, + Endpoints: []EndpointConf{ + EndpointTCPClient{"127.0.0.1:5600"}, + }, + HeartbeatDisable: true, + }) + require.NoError(t, err) + defer client1.Close() + + client2, err := NewNode(NodeConf{ + Dialect: testDialect, + OutVersion: V2, + OutSystemID: 11, + Endpoints: []EndpointConf{ + EndpointTCPClient{"127.0.0.1:5600"}, + }, + HeartbeatDisable: true, + }) + require.NoError(t, err) + defer client2.Close() + + evt := <-client1.Events() + _, ok := evt.(*EventChannelOpen) + require.Equal(t, true, ok) + + evt = <-client2.Events() + _, ok = evt.(*EventChannelOpen) + require.Equal(t, true, ok) + + evt = <-server.Events() + _, ok = evt.(*EventChannelOpen) + require.Equal(t, true, ok) + + evt = <-server.Events() + _, ok = evt.(*EventChannelOpen) + require.Equal(t, true, ok) + + fr := &frame.V2Frame{ + SequenceID: 0, + SystemID: 11, + ComponentID: 1, + Message: testMessage, + Checksum: 55967, + } + + err = client1.WriteFrameAll(fr) + require.NoError(t, err) + + err = client2.WriteFrameAll(fr) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + evt = <-server.Events() + fr, ok := evt.(*EventFrame) + require.Equal(t, true, ok) + require.Equal(t, &EventFrame{ + Frame: &frame.V2Frame{ + SequenceID: 0, + SystemID: 11, + ComponentID: 1, + Message: testMessage, + Checksum: 55967, + }, + Channel: fr.Channel, + }, evt) + } +}