diff --git a/interface.go b/interface.go index b6ab805..4a66222 100644 --- a/interface.go +++ b/interface.go @@ -31,6 +31,8 @@ type PeerI interface { RequestBlock(blockHash *chainhash.Hash) Network() wire.BitcoinNet IsHealthy() bool + Shutdown() + Restart() } type PeerHandlerI interface { diff --git a/peer.go b/peer.go index c2c9edf..ff84e96 100644 --- a/peer.go +++ b/peer.go @@ -128,6 +128,12 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw } } + p.start() + + return p, nil +} + +func (p *Peer) start() { ctx, cancelAll := context.WithCancel(context.Background()) p.cancelAll = cancelAll p.ctx = ctx @@ -143,20 +149,18 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw if p.incomingConn != nil { go func() { - connectErr := p.connectAndStartReadWriteHandlers() - if connectErr != nil { + err := p.connectAndStartReadWriteHandlers() + if err != nil { p.logger.Warn("Failed to connect to peer", slog.String(errKey, err.Error())) } }() p.logger.Info("Incoming connection from peer") - return p, nil + return } // reconnect if disconnected, but only on outgoing connections p.reconnectingWg.Add(1) go p.reconnect() - - return p, nil } func (p *Peer) disconnectLock() { @@ -848,6 +852,12 @@ func (p *Peer) stopWriteHandler() { p.writerWg.Wait() } +func (p *Peer) Restart() { + p.Shutdown() + + p.start() +} + func (p *Peer) Shutdown() { p.cancelAll() diff --git a/peer_mock.go b/peer_mock.go index 53b54c2..f60f51f 100644 --- a/peer_mock.go +++ b/peer_mock.go @@ -121,12 +121,8 @@ func (p *PeerMock) message(msg wire.Message) { p.messages = append(p.messages, msg) } -// func (p *PeerMock) getMessages() []wire.Message { -// p.mu.Lock() -// defer p.mu.Unlock() - -// return p.messages -// } +func (p *PeerMock) Shutdown() {} +func (p *PeerMock) Restart() {} func (p *PeerMock) WriteMsg(msg wire.Message) error { p.writeChan <- msg diff --git a/peer_test.go b/peer_test.go index 7304408..a5f9c2a 100644 --- a/peer_test.go +++ b/peer_test.go @@ -136,7 +136,7 @@ func TestShutdown(t *testing.T) { break connectLoop } case <-time.NewTimer(1 * time.Second).C: - t.Fatal("peer did not disconnect") + t.Fatal("peer did not connect") } } @@ -152,6 +152,120 @@ func TestShutdown(t *testing.T) { } } +func TestRestart(t *testing.T) { + tt := []struct { + name string + }{ + { + name: "Restart", + }, + } + + for _, tc := range tt { + t.Run(tc.name, func(t *testing.T) { + peerConn, myConn := connutil.AsyncPipe() + + peerHandler := NewMockPeerHandler() + logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})) + p, err := NewPeer( + logger, + "MockPeerHandler:0000", + peerHandler, + wire.MainNet, + WithDialer(func(network, address string) (net.Conn, error) { + return peerConn, nil + }), + WithRetryReadWriteMessageInterval(200*time.Millisecond), + ) + require.NoError(t, err) + + t.Log("handshake 1") + handshakeFinished := make(chan struct{}) + go func() { + doHandshake(t, p, myConn) + handshakeFinished <- struct{}{} + }() + + select { + case <-handshakeFinished: + t.Log("handshake 1 finished") + case <-time.After(5 * time.Second): + t.Fatal("handshake 1 timeout") + } + // wait for the peer to be connected + connectLoop: + for { + select { + case <-time.NewTicker(10 * time.Millisecond).C: + if p.Connected() { + break connectLoop + } + case <-time.NewTimer(1 * time.Second).C: + t.Fatal("peer did not connect") + } + } + + invMsg := wire.NewMsgInv() + hash, err := chainhash.NewHashFromStr(tx1) + require.NoError(t, err) + err = invMsg.AddInvVect(wire.NewInvVect(wire.InvTypeTx, hash)) + require.NoError(t, err) + + t.Log("restart") + p.Restart() + + // wait for the peer to be disconnected + disconnectLoop: + for { + select { + case <-time.NewTicker(10 * time.Millisecond).C: + if !p.Connected() { + break disconnectLoop + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatal("peer did not disconnect") + } + } + + //time.Sleep(15 * time.Second) + // recreate connection + p.mu.Lock() + peerConn, myConn = connutil.AsyncPipe() + p.mu.Unlock() + t.Log("new connection created") + + time.Sleep(5 * time.Second) + t.Log("handshake 2") + + go func() { + doHandshake(t, p, myConn) + handshakeFinished <- struct{}{} + }() + + select { + case <-handshakeFinished: + t.Log("handshake 2 finished") + case <-time.After(5 * time.Second): + t.Fatal("handshake 2 timeout") + } + + t.Log("reconnect") + // wait for the peer to be reconnected + reconnectLoop: + for { + select { + case <-time.NewTicker(10 * time.Millisecond).C: + if p.Connected() { + break reconnectLoop + } + case <-time.NewTimer(1 * time.Second).C: + t.Fatal("peer did not reconnect") + } + } + }) + } +} + func TestReconnect(t *testing.T) { tt := []struct { name string