diff --git a/interface.go b/interface.go index 6d98c9d..4a66222 100644 --- a/interface.go +++ b/interface.go @@ -18,6 +18,7 @@ type PeerManagerI interface { RequestBlock(blockHash *chainhash.Hash) PeerI AddPeer(peer PeerI) error GetPeers() []PeerI + Shutdown() } type PeerI interface { @@ -30,6 +31,8 @@ type PeerI interface { RequestBlock(blockHash *chainhash.Hash) Network() wire.BitcoinNet IsHealthy() bool + Shutdown() + Restart() } type PeerHandlerI interface { diff --git a/peer.go b/peer.go index 6e9a31b..ff84e96 100644 --- a/peer.go +++ b/peer.go @@ -35,6 +35,7 @@ const ( sentMsg = "Sent" receivedMsg = "Recv" + nrWriteHandlersDefault = 10 retryReadWriteMessageIntervalDefault = 1 * time.Second retryReadWriteMessageAttempts = 5 reconnectInterval = 10 * time.Second @@ -74,6 +75,7 @@ type Peer struct { userAgentName *string userAgentVersion *string retryReadWriteMessageInterval time.Duration + nrWriteHandlers int ctx context.Context @@ -107,6 +109,7 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw peerHandler: peerHandler, logger: peerLogger, dial: net.Dial, + nrWriteHandlers: nrWriteHandlersDefault, maximumMessageSize: defaultMaximumMessageSize, batchDelay: defaultBatchDelayMilliseconds * time.Millisecond, retryReadWriteMessageInterval: retryReadWriteMessageIntervalDefault, @@ -125,6 +128,12 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw } } + p.start() + + return p, nil +} + +func (p *Peer) start() { ctx, cancelAll := context.WithCancel(context.Background()) p.cancelAll = cancelAll p.ctx = ctx @@ -140,20 +149,18 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw if p.incomingConn != nil { go func() { - connectErr := p.connectAndStartReadWriteHandlers() - if connectErr != nil { + err := p.connectAndStartReadWriteHandlers() + if err != nil { p.logger.Warn("Failed to connect to peer", slog.String(errKey, err.Error())) } }() p.logger.Info("Incoming connection from peer") - return p, nil + return } // reconnect if disconnected, but only on outgoing connections p.reconnectingWg.Add(1) go p.reconnect() - - return p, nil } func (p *Peer) disconnectLock() { @@ -230,7 +237,7 @@ func (p *Peer) connectAndStartReadWriteHandlers() error { writerCtx, cancelWriter := context.WithCancel(p.ctx) p.cancelWriteHandler = cancelWriter - for i := 0; i < 10; i++ { + for i := 0; i < p.nrWriteHandlers; i++ { // start 10 workers that will write to the peer // locking is done in the net.write in the wire/message handler // this reduces the wait on the writer when processing writes (for example HandleTransactionSent) @@ -778,7 +785,7 @@ func (p *Peer) pingHandler() { case <-pingTicker.C: nonce, err := wire.RandomUint64() if err != nil { - p.logger.Error("Not sending ping", slog.String(errKey, err.Error())) + p.logger.Error("Failed to create random nonce - not sending ping", slog.String(errKey, err.Error())) continue } p.writeChan <- wire.NewMsgPing(nonce) @@ -810,6 +817,7 @@ func (p *Peer) monitorConnectionHealth() { case <-checkConnectionHealthTicker.C: p.mu.Lock() p.isHealthy = false + p.logger.Warn("peer unhealthy") p.mu.Unlock() case <-p.ctx.Done(): return @@ -844,6 +852,12 @@ func (p *Peer) stopWriteHandler() { p.writerWg.Wait() } +func (p *Peer) Restart() { + p.Shutdown() + + p.start() +} + func (p *Peer) Shutdown() { p.cancelAll() diff --git a/peer_manager.go b/peer_manager.go index 524309a..a10dd0d 100644 --- a/peer_manager.go +++ b/peer_manager.go @@ -67,6 +67,12 @@ func (pm *PeerManager) GetPeers() []PeerI { return peers } +func (pm *PeerManager) Shutdown() { + for _, peer := range pm.peers { + peer.Shutdown() + } +} + // AnnounceTransaction will send an INV message to the provided peers or to selected peers if peers is nil // it will return the peers that the transaction was actually announced to func (pm *PeerManager) AnnounceTransaction(txHash *chainhash.Hash, peers []PeerI) []PeerI { @@ -74,11 +80,15 @@ func (pm *PeerManager) AnnounceTransaction(txHash *chainhash.Hash, peers []PeerI peers = pm.GetAnnouncedPeers() } + announcedPeers := make([]PeerI, 0, len(peers)) for _, peer := range peers { - peer.AnnounceTransaction(txHash) + if peer.Connected() && peer.IsHealthy() { + peer.AnnounceTransaction(txHash) + announcedPeers = append(announcedPeers, peer) + } } - return peers + return announcedPeers } func (pm *PeerManager) RequestTransaction(txHash *chainhash.Hash) PeerI { diff --git a/peer_mock.go b/peer_mock.go index 53b54c2..f60f51f 100644 --- a/peer_mock.go +++ b/peer_mock.go @@ -121,12 +121,8 @@ func (p *PeerMock) message(msg wire.Message) { p.messages = append(p.messages, msg) } -// func (p *PeerMock) getMessages() []wire.Message { -// p.mu.Lock() -// defer p.mu.Unlock() - -// return p.messages -// } +func (p *PeerMock) Shutdown() {} +func (p *PeerMock) Restart() {} func (p *PeerMock) WriteMsg(msg wire.Message) error { p.writeChan <- msg diff --git a/peer_options.go b/peer_options.go index b492e7f..432ebeb 100644 --- a/peer_options.go +++ b/peer_options.go @@ -54,3 +54,10 @@ func WithRetryReadWriteMessageInterval(d time.Duration) PeerOptions { return nil } } + +func WithNrOfWriteHandlers(NrWriteHandlers int) PeerOptions { + return func(p *Peer) error { + p.nrWriteHandlers = NrWriteHandlers + return nil + } +} diff --git a/peer_test.go b/peer_test.go index 7304408..a5f9c2a 100644 --- a/peer_test.go +++ b/peer_test.go @@ -136,7 +136,7 @@ func TestShutdown(t *testing.T) { break connectLoop } case <-time.NewTimer(1 * time.Second).C: - t.Fatal("peer did not disconnect") + t.Fatal("peer did not connect") } } @@ -152,6 +152,120 @@ func TestShutdown(t *testing.T) { } } +func TestRestart(t *testing.T) { + tt := []struct { + name string + }{ + { + name: "Restart", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + peerConn, myConn := connutil.AsyncPipe() + + peerHandler := NewMockPeerHandler() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + p, err := NewPeer( + logger, + "MockPeerHandler:0000", + peerHandler, + wire.MainNet, + WithDialer(func(network, address string) (net.Conn, error) { + return peerConn, nil + }), + WithRetryReadWriteMessageInterval(200*time.Millisecond), + ) + require.NoError(t, err) + + t.Log("handshake 1") + handshakeFinished := make(chan struct{}) + go func() { + doHandshake(t, p, myConn) + handshakeFinished <- struct{}{} + }() + + select { + case <-handshakeFinished: + t.Log("handshake 1 finished") + case <-time.After(5 * time.Second): + t.Fatal("handshake 1 timeout") + } + // wait for the peer to be connected + connectLoop: + for { + select { + case <-time.NewTicker(10 * time.Millisecond).C: + if p.Connected() { + break connectLoop + } + case <-time.NewTimer(1 * time.Second).C: + t.Fatal("peer did not connect") + } + } + + invMsg := wire.NewMsgInv() + hash, err := chainhash.NewHashFromStr(tx1) + require.NoError(t, err) + err = invMsg.AddInvVect(wire.NewInvVect(wire.InvTypeTx, hash)) + require.NoError(t, err) + + t.Log("restart") + p.Restart() + + // wait for the peer to be disconnected + disconnectLoop: + for { + select { + case <-time.NewTicker(10 * time.Millisecond).C: + if !p.Connected() { + break disconnectLoop + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatal("peer did not disconnect") + } + } + + //time.Sleep(15 * time.Second) + // recreate connection + p.mu.Lock() + peerConn, myConn = connutil.AsyncPipe() + p.mu.Unlock() + t.Log("new connection created") + + time.Sleep(5 * time.Second) + t.Log("handshake 2") + + go func() { + doHandshake(t, p, myConn) + handshakeFinished <- struct{}{} + }() + + select { + case <-handshakeFinished: + t.Log("handshake 2 finished") + case <-time.After(5 * time.Second): + t.Fatal("handshake 2 timeout") + } + + t.Log("reconnect") + // wait for the peer to be reconnected + reconnectLoop: + for { + select { + case <-time.NewTicker(10 * time.Millisecond).C: + if p.Connected() { + break reconnectLoop + } + case <-time.NewTimer(1 * time.Second).C: + t.Fatal("peer did not reconnect") + } + } + }) + } +} + func TestReconnect(t *testing.T) { tt := []struct { name string