diff --git a/Peer.go b/Peer.go index af0a012..e1d453d 100644 --- a/Peer.go +++ b/Peer.go @@ -26,6 +26,10 @@ var ( pingInterval = 2 * time.Minute ) +const ( + defaultMaximumMessageSize = 32 * 1024 * 1024 +) + type Block struct { Hash *chainhash.Hash `json:"hash,omitempty"` // Little endian PreviousHash *chainhash.Hash `json:"previous_hash,omitempty"` // Little endian @@ -36,22 +40,23 @@ type Block struct { } type Peer struct { - address string - network wire.BitcoinNet - mu sync.RWMutex - readConn net.Conn - writeConn net.Conn - incomingConn net.Conn - dial func(network, address string) (net.Conn, error) - peerHandler PeerHandlerI - writeChan chan wire.Message - quit chan struct{} - logger utils.Logger - sentVerAck atomic.Bool - receivedVerAck atomic.Bool - batchDelay time.Duration - invBatcher *batcher.Batcher[chainhash.Hash] - dataBatcher *batcher.Batcher[chainhash.Hash] + address string + network wire.BitcoinNet + mu sync.RWMutex + readConn net.Conn + writeConn net.Conn + incomingConn net.Conn + dial func(network, address string) (net.Conn, error) + peerHandler PeerHandlerI + writeChan chan wire.Message + quit chan struct{} + logger utils.Logger + sentVerAck atomic.Bool + receivedVerAck atomic.Bool + batchDelay time.Duration + invBatcher *batcher.Batcher[chainhash.Hash] + dataBatcher *batcher.Batcher[chainhash.Hash] + maximumMessageSize int64 } // NewPeer returns a new bitcoin peer for the provided address and configuration. @@ -59,12 +64,13 @@ func NewPeer(logger utils.Logger, address string, peerHandler PeerHandlerI, netw writeChan := make(chan wire.Message, 10000) p := &Peer{ - network: network, - address: address, - writeChan: writeChan, - peerHandler: peerHandler, - logger: logger, - dial: net.Dial, + network: network, + address: address, + writeChan: writeChan, + peerHandler: peerHandler, + logger: logger, + dial: net.Dial, + maximumMessageSize: defaultMaximumMessageSize, } for _, option := range options { @@ -221,127 +227,131 @@ func (p *Peer) String() string { func (p *Peer) readHandler() { readConn := p.readConn - if readConn != nil { - reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: 32 * 1024 * 1024}) - for { - msg, b, err := wire.ReadMessage(reader, wire.ProtocolVersion, p.network) - if err != nil { - if errors.Is(err, io.EOF) { - p.logger.Errorf(fmt.Sprintf("READ EOF whilst reading from %s [%d bytes], are you on the right network?\n%s", p.address, len(b), string(b))) - p.disconnect() - break - } - p.logger.Errorf("[%s] Failed to read message: %v", p.address, err) - continue + if readConn == nil { + p.logger.Errorf("no connection") + return + } + + reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: p.maximumMessageSize}) + for { + msg, b, err := wire.ReadMessage(reader, wire.ProtocolVersion, p.network) + if err != nil { + if errors.Is(err, io.EOF) { + p.logger.Errorf(fmt.Sprintf("READ EOF whilst reading from %s [%d bytes], bytes = %s, err = %v", p.address, len(b), string(b), err)) + p.disconnect() + break } + p.logger.Errorf("[%s] Failed to read message: %v", p.address, err) + continue + } - // we could check this based on type (switch msg.(type)) but that would not allow - // us to override the default behaviour for a specific message type - switch msg.Command() { - case wire.CmdVersion: - p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command())) - if p.sentVerAck.Load() { - p.logger.Warnf("[%s] Received version message after sending verack", p.address) - continue - } + // we could check this based on type (switch msg.(type)) but that would not allow + // us to override the default behaviour for a specific message type + switch msg.Command() { + case wire.CmdVersion: + p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command())) + if p.sentVerAck.Load() { + p.logger.Warnf("[%s] Received version message after sending verack", p.address) + continue + } - verackMsg := wire.NewMsgVerAck() - if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil { - p.logger.Errorf("[%s] failed to write message: %v", p.address, err) - } - p.logger.Debugf("[%s] Sent %s", p.address, strings.ToUpper(verackMsg.Command())) - p.sentVerAck.Store(true) - - case wire.CmdPing: - pingMsg := msg.(*wire.MsgPing) - p.writeChan <- wire.NewMsgPong(pingMsg.Nonce) - - case wire.CmdInv: - invMsg := msg.(*wire.MsgInv) - if p.logger.LogLevel() == int(gocore.DEBUG) { - p.logger.Debugf("[%s] Recv INV (%d items)", p.address, len(invMsg.InvList)) - for _, inv := range invMsg.InvList { - p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String()) - } + verackMsg := wire.NewMsgVerAck() + if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil { + p.logger.Errorf("[%s] failed to write message: %v", p.address, err) + } + p.logger.Debugf("[%s] Sent %s", p.address, strings.ToUpper(verackMsg.Command())) + p.sentVerAck.Store(true) + + case wire.CmdPing: + pingMsg := msg.(*wire.MsgPing) + p.writeChan <- wire.NewMsgPong(pingMsg.Nonce) + + case wire.CmdInv: + invMsg := msg.(*wire.MsgInv) + if p.logger.LogLevel() == int(gocore.DEBUG) { + p.logger.Debugf("[%s] Recv INV (%d items)", p.address, len(invMsg.InvList)) + for _, inv := range invMsg.InvList { + p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String()) } + } - go func(invList []*wire.InvVect) { - for _, invVect := range invList { - switch invVect.Type { - case wire.InvTypeTx: - if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil { - p.logger.Errorf("[%s] Unable to process tx %s: %v", p.address, invVect.Hash.String(), err) - } - case wire.InvTypeBlock: - if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil { - p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, invVect.Hash.String(), err) - } + go func(invList []*wire.InvVect) { + for _, invVect := range invList { + switch invVect.Type { + case wire.InvTypeTx: + if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil { + p.logger.Errorf("[%s] Unable to process tx %s: %v", p.address, invVect.Hash.String(), err) + } + case wire.InvTypeBlock: + if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil { + p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, invVect.Hash.String(), err) } - } - }(invMsg.InvList) - - case wire.CmdGetData: - dataMsg := msg.(*wire.MsgGetData) - p.logger.Infof("[%s] Recv GETDATA (%d items)", p.address, len(dataMsg.InvList)) - if p.logger.LogLevel() == int(gocore.DEBUG) { - for _, inv := range dataMsg.InvList { - p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String()) } } - p.handleGetDataMsg(dataMsg) - - case wire.CmdTx: - txMsg := msg.(*wire.MsgTx) - p.logger.Debugf("Recv TX %s (%d bytes)", txMsg.TxHash().String(), txMsg.SerializeSize()) - if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil { - p.logger.Errorf("Unable to process tx %s: %v", txMsg.TxHash().String(), err) + }(invMsg.InvList) + + case wire.CmdGetData: + dataMsg := msg.(*wire.MsgGetData) + p.logger.Infof("[%s] Recv GETDATA (%d items)", p.address, len(dataMsg.InvList)) + if p.logger.LogLevel() == int(gocore.DEBUG) { + for _, inv := range dataMsg.InvList { + p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String()) } + } + p.handleGetDataMsg(dataMsg) - case wire.CmdBlock: - msgBlock, ok := msg.(*wire.MsgBlock) - if ok { - p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), msgBlock.Header.BlockHash().String()) + case wire.CmdTx: + txMsg := msg.(*wire.MsgTx) + p.logger.Debugf("Recv TX %s (%d bytes)", txMsg.TxHash().String(), txMsg.SerializeSize()) + if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil { + p.logger.Errorf("Unable to process tx %s: %v", txMsg.TxHash().String(), err) + } - err = p.peerHandler.HandleBlock(msgBlock, p) - if err != nil { - p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, msgBlock.Header.BlockHash().String(), err) - } - continue - } + case wire.CmdBlock: + msgBlock, ok := msg.(*wire.MsgBlock) + if ok { + p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), msgBlock.Header.BlockHash().String()) - // Please note that this is the BlockMessage, not the wire.MsgBlock - blockMsg, ok := msg.(*BlockMessage) - if !ok { - p.logger.Errorf("Unable to cast block message, calling with generic wire.Message") - err = p.peerHandler.HandleBlock(msg, p) - if err != nil { - p.logger.Errorf("[%s] Unable to process block message: %v", p.address, err) - } - continue + err = p.peerHandler.HandleBlock(msgBlock, p) + if err != nil { + p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, msgBlock.Header.BlockHash().String(), err) } + continue + } - p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), blockMsg.Header.BlockHash().String()) - - err = p.peerHandler.HandleBlock(blockMsg, p) + // Please note that this is the BlockMessage, not the wire.MsgBlock + blockMsg, ok := msg.(*BlockMessage) + if !ok { + p.logger.Errorf("Unable to cast block message, calling with generic wire.Message") + err = p.peerHandler.HandleBlock(msg, p) if err != nil { - p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, blockMsg.Header.BlockHash().String(), err) + p.logger.Errorf("[%s] Unable to process block message: %v", p.address, err) } + continue + } - case wire.CmdReject: - rejMsg := msg.(*wire.MsgReject) - if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil { - p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, rejMsg.Hash.String(), err) - } + p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), blockMsg.Header.BlockHash().String()) - case wire.CmdVerAck: - p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command())) - p.receivedVerAck.Store(true) + err = p.peerHandler.HandleBlock(blockMsg, p) + if err != nil { + p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, blockMsg.Header.BlockHash().String(), err) + } - default: - p.logger.Debugf("[%s] Ignored %s", p.address, strings.ToUpper(msg.Command())) + case wire.CmdReject: + rejMsg := msg.(*wire.MsgReject) + if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil { + p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, rejMsg.Hash.String(), err) } + + case wire.CmdVerAck: + p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command())) + p.receivedVerAck.Store(true) + + default: + p.logger.Debugf("[%s] Ignored %s", p.address, strings.ToUpper(msg.Command())) } } + } func (p *Peer) handleGetDataMsg(dataMsg *wire.MsgGetData) { diff --git a/Peer_Options.go b/Peer_Options.go index aa030d7..a1adb98 100644 --- a/Peer_Options.go +++ b/Peer_Options.go @@ -24,3 +24,9 @@ func WithIncomingConnection(conn net.Conn) PeerOptions { p.incomingConn = conn } } + +func WithMaximumMessageSize(maximumMessageSize int64) PeerOptions { + return func(p *Peer) { + p.maximumMessageSize = maximumMessageSize + } +}