Skip to content

Commit

Permalink
Graceful shutdown of read handler go routine
Browse files Browse the repository at this point in the history
  • Loading branch information
boecklim committed Mar 19, 2024
1 parent b7d3e09 commit ca9e440
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 119 deletions.
260 changes: 141 additions & 119 deletions Peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ type Peer struct {
dataBatcher *batcher.Batcher[chainhash.Hash]
maximumMessageSize int64
isHealthy bool
quitReadHandler chan struct{}
}

// NewPeer returns a new bitcoin peer for the provided address and configuration.
Expand Down Expand Up @@ -192,7 +193,7 @@ func (p *Peer) connect() error {
p.readConn = conn
}

go p.readHandler()
p.startReadHandler()

// write version message to our peer directly and not through the write channel,
// write channel is not ready to send message until the VERACK handshake is done
Expand Down Expand Up @@ -278,150 +279,163 @@ func (p *Peer) readRetry(r io.Reader, pver uint32, bsvnet wire.BitcoinNet) (wire
return msg, nil
}

func (p *Peer) readHandler() {
readConn := p.readConn
func (p *Peer) startReadHandler() {
p.quitReadHandler = make(chan struct{})

if readConn == nil {
p.logger.Error("no connection")
return
}
go func() {

reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: p.maximumMessageSize})
for {
msg, err := p.readRetry(reader, wire.ProtocolVersion, p.network)
if err != nil {
p.logger.Error("Retrying to read failed", slog.String(errKey, err.Error()))
readConn := p.readConn

// by disconnecting ensure that peer will try to reconnect
p.disconnect()
if readConn == nil {
p.logger.Error("no connection")
return
}

commandLogger := p.logger.With(slog.String(commandKey, strings.ToUpper(msg.Command())))
reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: p.maximumMessageSize})
for {
select {
case <-p.quitReadHandler:
return
default:
msg, err := p.readRetry(reader, wire.ProtocolVersion, p.network)
if err != nil {
p.logger.Error("Retrying to read failed", slog.String(errKey, err.Error()))

// we could check this based on type (switch msg.(type)) but that would not allow
// us to override the default behaviour for a specific message type
switch msg.Command() {
case wire.CmdVersion:
commandLogger.Debug(receivedMsg)
if p.sentVerAck.Load() {
commandLogger.Warn("Received version message after sending verack")
continue
}
p.disconnect()

verackMsg := wire.NewMsgVerAck()
if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil {
commandLogger.Error("failed to write message", slog.String(errKey, err.Error()))
}
commandLogger.Debug(sentMsg, slog.String(commandKey, strings.ToUpper(verackMsg.Command())))
p.sentVerAck.Store(true)
p.mu.Lock()
p.quitReadHandler = nil
p.mu.Unlock()

case wire.CmdPing:
commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPing)))
p.pingPongAlive <- struct{}{}
return
}

pingMsg, ok := msg.(*wire.MsgPing)
if !ok {
continue
}
p.writeChan <- wire.NewMsgPong(pingMsg.Nonce)
commandLogger := p.logger.With(slog.String(commandKey, strings.ToUpper(msg.Command())))

case wire.CmdInv:
invMsg, ok := msg.(*wire.MsgInv)
if !ok {
continue
}
for _, inv := range invMsg.InvList {
commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String()))
}
// we could check this based on type (switch msg.(type)) but that would not allow
// us to override the default behaviour for a specific message type
switch msg.Command() {
case wire.CmdVersion:
commandLogger.Debug(receivedMsg)
if p.sentVerAck.Load() {
commandLogger.Warn("Received version message after sending verack")
continue
}

go func(invList []*wire.InvVect, routineLogger *slog.Logger) {
for _, invVect := range invList {
switch invVect.Type {
case wire.InvTypeTx:
if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil {
commandLogger.Error("Unable to process tx", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error()))
}
case wire.InvTypeBlock:
if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error()))
}
verackMsg := wire.NewMsgVerAck()
if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil {
commandLogger.Error("failed to write message", slog.String(errKey, err.Error()))
}
}
}(invMsg.InvList, commandLogger)
commandLogger.Debug(sentMsg, slog.String(commandKey, strings.ToUpper(verackMsg.Command())))
p.sentVerAck.Store(true)

case wire.CmdGetData:
dataMsg, ok := msg.(*wire.MsgGetData)
if !ok {
continue
}
for _, inv := range dataMsg.InvList {
commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String()))
}
p.handleGetDataMsg(dataMsg, commandLogger)
case wire.CmdPing:
commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPing)))
p.pingPongAlive <- struct{}{}

case wire.CmdTx:
txMsg, ok := msg.(*wire.MsgTx)
if !ok {
continue
}
commandLogger.Debug(receivedMsg, slog.String(hashKey, txMsg.TxHash().String()), slog.Int("size", txMsg.SerializeSize()))
if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil {
commandLogger.Error("Unable to process tx", slog.String(hashKey, txMsg.TxHash().String()), slog.String(errKey, err.Error()))
}
pingMsg, ok := msg.(*wire.MsgPing)
if !ok {
continue
}
p.writeChan <- wire.NewMsgPong(pingMsg.Nonce)

case wire.CmdBlock:
msgBlock, ok := msg.(*wire.MsgBlock)
if ok {
commandLogger.Info(receivedMsg, slog.String(hashKey, msgBlock.Header.BlockHash().String()))
case wire.CmdInv:
invMsg, ok := msg.(*wire.MsgInv)
if !ok {
continue
}
for _, inv := range invMsg.InvList {
commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String()))
}

err = p.peerHandler.HandleBlock(msgBlock, p)
if err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, msgBlock.Header.BlockHash().String()), slog.String(errKey, err.Error()))
}
continue
}
go func(invList []*wire.InvVect, routineLogger *slog.Logger) {
for _, invVect := range invList {
switch invVect.Type {
case wire.InvTypeTx:
if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil {
commandLogger.Error("Unable to process tx", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error()))
}
case wire.InvTypeBlock:
if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, invVect.Hash.String()), slog.String(typeKey, invVect.Type.String()), slog.String(errKey, err.Error()))
}
}
}
}(invMsg.InvList, commandLogger)

// Please note that this is the BlockMessage, not the wire.MsgBlock
blockMsg, ok := msg.(*BlockMessage)
if !ok {
commandLogger.Error("Unable to cast block message, calling with generic wire.Message")
err = p.peerHandler.HandleBlock(msg, p)
if err != nil {
commandLogger.Error("Unable to process block message", slog.String(errKey, err.Error()))
}
continue
}
case wire.CmdGetData:
dataMsg, ok := msg.(*wire.MsgGetData)
if !ok {
continue
}
for _, inv := range dataMsg.InvList {
commandLogger.Debug(receivedMsg, slog.String(hashKey, inv.Hash.String()), slog.String(typeKey, inv.Type.String()))
}
p.handleGetDataMsg(dataMsg, commandLogger)

commandLogger.Info(receivedMsg, slog.String(hashKey, blockMsg.Header.BlockHash().String()))
case wire.CmdTx:
txMsg, ok := msg.(*wire.MsgTx)
if !ok {
continue
}
commandLogger.Debug(receivedMsg, slog.String(hashKey, txMsg.TxHash().String()), slog.Int("size", txMsg.SerializeSize()))
if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil {
commandLogger.Error("Unable to process tx", slog.String(hashKey, txMsg.TxHash().String()), slog.String(errKey, err.Error()))
}

err = p.peerHandler.HandleBlock(blockMsg, p)
if err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, blockMsg.Header.BlockHash().String()), slog.String(errKey, err.Error()))
}
case wire.CmdBlock:
msgBlock, ok := msg.(*wire.MsgBlock)
if ok {
commandLogger.Info(receivedMsg, slog.String(hashKey, msgBlock.Header.BlockHash().String()))

case wire.CmdReject:
rejMsg, ok := msg.(*wire.MsgReject)
if !ok {
continue
}
if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, rejMsg.Hash.String()), slog.String(errKey, err.Error()))
}
err = p.peerHandler.HandleBlock(msgBlock, p)
if err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, msgBlock.Header.BlockHash().String()), slog.String(errKey, err.Error()))
}
continue
}

// Please note that this is the BlockMessage, not the wire.MsgBlock
blockMsg, ok := msg.(*BlockMessage)
if !ok {
commandLogger.Error("Unable to cast block message, calling with generic wire.Message")
err = p.peerHandler.HandleBlock(msg, p)
if err != nil {
commandLogger.Error("Unable to process block message", slog.String(errKey, err.Error()))
}
continue
}

case wire.CmdVerAck:
commandLogger.Debug(receivedMsg)
p.receivedVerAck.Store(true)
commandLogger.Info(receivedMsg, slog.String(hashKey, blockMsg.Header.BlockHash().String()))

case wire.CmdPong:
commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPong)))
p.pingPongAlive <- struct{}{}
err = p.peerHandler.HandleBlock(blockMsg, p)
if err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, blockMsg.Header.BlockHash().String()), slog.String(errKey, err.Error()))
}

default:
case wire.CmdReject:
rejMsg, ok := msg.(*wire.MsgReject)
if !ok {
continue
}
if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil {
commandLogger.Error("Unable to process block", slog.String(hashKey, rejMsg.Hash.String()), slog.String(errKey, err.Error()))
}

case wire.CmdVerAck:
commandLogger.Debug(receivedMsg)
p.receivedVerAck.Store(true)

commandLogger.Debug("command ignored")
case wire.CmdPong:
commandLogger.Debug(receivedMsg, slog.String(commandKey, strings.ToUpper(wire.CmdPong)))
p.pingPongAlive <- struct{}{}

default:
commandLogger.Debug("command ignored")
}
}
}
}
}()
}

func (p *Peer) handleGetDataMsg(dataMsg *wire.MsgGetData, logger *slog.Logger) {
Expand Down Expand Up @@ -665,3 +679,11 @@ func (p *Peer) IsHealthy() bool {

return p.isHealthy
}

func (p *Peer) Shutdown() {
p.mu.Lock()
defer p.mu.Unlock()
if p.quitReadHandler != nil {
p.quitReadHandler <- struct{}{}
}
}
4 changes: 4 additions & 0 deletions PeerManager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ func TestNewPeerManager(t *testing.T) {
peerHandler := NewMockPeerHandler()

peer, err := NewPeer(logger, "localhost:18333", peerHandler, wire.TestNet)
defer peer.Shutdown()
require.NoError(t, err)

err = pm.AddPeer(peer)
Expand All @@ -57,6 +58,7 @@ func TestNewPeerManager(t *testing.T) {
for _, peerAddress := range peers {
peer, _ := NewPeer(logger, peerAddress, peerHandler, wire.TestNet)
_ = pm.AddPeer(peer)
defer peer.Shutdown()
}

assert.Len(t, pm.GetPeers(), 4)
Expand All @@ -73,6 +75,8 @@ func TestAnnounceNewTransaction(t *testing.T) {

peer, _ := NewPeerMock("localhost:18333", peerHandler, wire.TestNet)
err := pm.AddPeer(peer)
defer peer.Shutdown()

require.NoError(t, err)

pm.AnnounceTransaction(tx1Hash, nil)
Expand Down
2 changes: 2 additions & 0 deletions Peer_Mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ func (p *PeerMock) IsHealthy() bool {
return true
}

func (p *PeerMock) Shutdown() {}

func (p *PeerMock) Connected() bool {
return true
}
Expand Down
1 change: 1 addition & 0 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ type PeerI interface {
RequestBlock(blockHash *chainhash.Hash)
Network() wire.BitcoinNet
IsHealthy() bool
Shutdown()
}

type PeerHandlerI interface {
Expand Down

0 comments on commit ca9e440

Please sign in to comment.