Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

peer: test for startup writeHander data race #8198

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions peer/brontide_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1344,3 +1344,109 @@ func TestHandleRemovePendingChannel(t *testing.T) {
})
}
}

// TestStartupWriteMessageRace checks that no data race occurs when starting up
// a peer with an existing channel, while an outgoing message is queuing. Such
// a race occurred in https://github.com/lightningnetwork/lnd/issues/8184, where
// a channel reestablish message raced with another outgoing message.
//
// Note that races will only be detected with the Go race detector enabled.
func TestStartupWriteMessageRace(t *testing.T) {
t.Parallel()

// Set up parameters for createTestPeer.
notifier := &mock.ChainNotifier{
SpendChan: make(chan *chainntnfs.SpendDetail),
EpochChan: make(chan *chainntnfs.BlockEpoch),
ConfChan: make(chan *chainntnfs.TxConfirmation),
}
broadcastTxChan := make(chan *wire.MsgTx)
mockSwitch := &mockMessageSwitch{}

// Use a callback to extract the channel created by createTestPeer, so
// we can mark it borked below. We can't mark it borked within the
// callback, since the channel hasn't been saved to the DB yet when the
// callback executes.
var channel *channeldb.OpenChannel
getChannels := func(a, b *channeldb.OpenChannel) {
channel = a
}

// createTestPeer creates a peer and a channel with that peer.
peer, _, err := createTestPeer(
t, notifier, broadcastTxChan, getChannels, mockSwitch,
)
require.NoError(t, err, "unable to create test channel")

// Avoid the need to mock the channel graph by marking the channel
// borked. Borked channels still get a reestablish message sent on
// reconnect, while skipping channel graph checks and link creation.
require.NoError(t, channel.MarkBorked())

// Use a mock conn to detect read/write races on the conn.
mockConn := newMockConn(t, 2)
peer.cfg.Conn = mockConn

// Set up other configuration necessary to successfully execute
// peer.Start().
peer.cfg.LegacyFeatures = lnwire.EmptyFeatureVector()
writeBufferPool := pool.NewWriteBuffer(
pool.DefaultWriteBufferGCInterval,
pool.DefaultWriteBufferExpiryInterval,
)
writePool := pool.NewWrite(
writeBufferPool, 1, timeout,
)
require.NoError(t, writePool.Start())
peer.cfg.WritePool = writePool
readBufferPool := pool.NewReadBuffer(
pool.DefaultReadBufferGCInterval,
pool.DefaultReadBufferExpiryInterval,
)
readPool := pool.NewRead(
readBufferPool, 1, timeout,
)
require.NoError(t, readPool.Start())
peer.cfg.ReadPool = readPool

// Send a message while starting the peer. As the peer starts up, it
// should not trigger a data race between the sending of this message
// and the sending of the channel reestablish message.
sendPingDone := make(chan struct{})
go func() {
require.NoError(t, peer.SendMessage(true, lnwire.NewPing(0)))
close(sendPingDone)
}()

// Handle init messages.
go func() {
// Read init message.
<-mockConn.writtenMessages

// Write the init reply message.
initReplyMsg := lnwire.NewInitMessage(
lnwire.NewRawFeatureVector(
lnwire.DataLossProtectRequired,
),
lnwire.NewRawFeatureVector(),
)
var b bytes.Buffer
_, err = lnwire.WriteMessage(&b, initReplyMsg, 0)
require.NoError(t, err)

mockConn.readMessages <- b.Bytes()
}()

// Start the peer. No data race should occur.
require.NoError(t, peer.Start())

// Ensure messages were sent during startup.
<-sendPingDone
for i := 0; i < 2; i++ {
select {
case <-mockConn.writtenMessages:
default:
t.Fatalf("Failed to send all messages during startup")
}
}
}
20 changes: 20 additions & 0 deletions peer/test_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,16 @@ type mockMessageConn struct {

readMessages chan []byte
curReadMessage []byte

// writeRaceDetectingCounter is incremented on any function call
// associated with writing to the connection. The race detector will
// trigger on this counter if a data race exists.
writeRaceDetectingCounter int
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh nice little trick: unsafe concurrent access here will cause the race detector to fire!


// readRaceDetectingCounter is incremented on any function call
// associated with reading from the connection. The race detector will
// trigger on this counter if a data race exists.
readRaceDetectingCounter int
}

func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {
Expand All @@ -509,17 +519,20 @@ func newMockConn(t *testing.T, expectedMessages int) *mockMessageConn {

// SetWriteDeadline mocks setting write deadline for our conn.
func (m *mockMessageConn) SetWriteDeadline(time.Time) error {
m.writeRaceDetectingCounter++
return nil
}

// Flush mocks a message conn flush.
func (m *mockMessageConn) Flush() (int, error) {
m.writeRaceDetectingCounter++
return 0, nil
}

// WriteMessage mocks sending of a message on our connection. It will push
// the bytes sent into the mock's writtenMessages channel.
func (m *mockMessageConn) WriteMessage(msg []byte) error {
m.writeRaceDetectingCounter++
select {
case m.writtenMessages <- msg:
case <-time.After(timeout):
Expand All @@ -542,15 +555,18 @@ func (m *mockMessageConn) assertWrite(expected []byte) {
}

func (m *mockMessageConn) SetReadDeadline(t time.Time) error {
m.readRaceDetectingCounter++
return nil
}

func (m *mockMessageConn) ReadNextHeader() (uint32, error) {
m.readRaceDetectingCounter++
m.curReadMessage = <-m.readMessages
return uint32(len(m.curReadMessage)), nil
}

func (m *mockMessageConn) ReadNextBody(buf []byte) ([]byte, error) {
m.readRaceDetectingCounter++
return m.curReadMessage, nil
}

Expand All @@ -561,3 +577,7 @@ func (m *mockMessageConn) RemoteAddr() net.Addr {
func (m *mockMessageConn) LocalAddr() net.Addr {
return nil
}

func (m *mockMessageConn) Close() error {
return nil
}
Loading