Skip to content

Commit

Permalink
Allow setting max bytes of wire reader in peer as option
Browse files Browse the repository at this point in the history
  • Loading branch information
boecklim committed May 30, 2023
1 parent d313b9b commit 822970a
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 121 deletions.
252 changes: 131 additions & 121 deletions Peer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ var (
pingInterval = 2 * time.Minute
)

const (
defaultMaximumMessageSize = 32 * 1024 * 1024
)

type Block struct {
Hash *chainhash.Hash `json:"hash,omitempty"` // Little endian
PreviousHash *chainhash.Hash `json:"previous_hash,omitempty"` // Little endian
Expand All @@ -36,35 +40,37 @@ type Block struct {
}

type Peer struct {
address string
network wire.BitcoinNet
mu sync.RWMutex
readConn net.Conn
writeConn net.Conn
incomingConn net.Conn
dial func(network, address string) (net.Conn, error)
peerHandler PeerHandlerI
writeChan chan wire.Message
quit chan struct{}
logger utils.Logger
sentVerAck atomic.Bool
receivedVerAck atomic.Bool
batchDelay time.Duration
invBatcher *batcher.Batcher[chainhash.Hash]
dataBatcher *batcher.Batcher[chainhash.Hash]
address string
network wire.BitcoinNet
mu sync.RWMutex
readConn net.Conn
writeConn net.Conn
incomingConn net.Conn
dial func(network, address string) (net.Conn, error)
peerHandler PeerHandlerI
writeChan chan wire.Message
quit chan struct{}
logger utils.Logger
sentVerAck atomic.Bool
receivedVerAck atomic.Bool
batchDelay time.Duration
invBatcher *batcher.Batcher[chainhash.Hash]
dataBatcher *batcher.Batcher[chainhash.Hash]
maximumMessageSize int64
}

// NewPeer returns a new bitcoin peer for the provided address and configuration.
func NewPeer(logger utils.Logger, address string, peerHandler PeerHandlerI, network wire.BitcoinNet, options ...PeerOptions) (*Peer, error) {
writeChan := make(chan wire.Message, 10000)

p := &Peer{
network: network,
address: address,
writeChan: writeChan,
peerHandler: peerHandler,
logger: logger,
dial: net.Dial,
network: network,
address: address,
writeChan: writeChan,
peerHandler: peerHandler,
logger: logger,
dial: net.Dial,
maximumMessageSize: defaultMaximumMessageSize,
}

for _, option := range options {
Expand Down Expand Up @@ -221,127 +227,131 @@ func (p *Peer) String() string {
func (p *Peer) readHandler() {
readConn := p.readConn

if readConn != nil {
reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: 32 * 1024 * 1024})
for {
msg, b, err := wire.ReadMessage(reader, wire.ProtocolVersion, p.network)
if err != nil {
if errors.Is(err, io.EOF) {
p.logger.Errorf(fmt.Sprintf("READ EOF whilst reading from %s [%d bytes], are you on the right network?\n%s", p.address, len(b), string(b)))
p.disconnect()
break
}
p.logger.Errorf("[%s] Failed to read message: %v", p.address, err)
continue
if readConn == nil {
p.logger.Errorf("no connection")
return
}

reader := bufio.NewReader(&io.LimitedReader{R: readConn, N: p.maximumMessageSize})
for {
msg, b, err := wire.ReadMessage(reader, wire.ProtocolVersion, p.network)
if err != nil {
if errors.Is(err, io.EOF) {
p.logger.Errorf(fmt.Sprintf("READ EOF whilst reading from %s [%d bytes], bytes = %s, err = %v", p.address, len(b), string(b), err))
p.disconnect()
break
}
p.logger.Errorf("[%s] Failed to read message: %v", p.address, err)
continue
}

// we could check this based on type (switch msg.(type)) but that would not allow
// us to override the default behaviour for a specific message type
switch msg.Command() {
case wire.CmdVersion:
p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command()))
if p.sentVerAck.Load() {
p.logger.Warnf("[%s] Received version message after sending verack", p.address)
continue
}
// we could check this based on type (switch msg.(type)) but that would not allow
// us to override the default behaviour for a specific message type
switch msg.Command() {
case wire.CmdVersion:
p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command()))
if p.sentVerAck.Load() {
p.logger.Warnf("[%s] Received version message after sending verack", p.address)
continue
}

verackMsg := wire.NewMsgVerAck()
if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil {
p.logger.Errorf("[%s] failed to write message: %v", p.address, err)
}
p.logger.Debugf("[%s] Sent %s", p.address, strings.ToUpper(verackMsg.Command()))
p.sentVerAck.Store(true)

case wire.CmdPing:
pingMsg := msg.(*wire.MsgPing)
p.writeChan <- wire.NewMsgPong(pingMsg.Nonce)

case wire.CmdInv:
invMsg := msg.(*wire.MsgInv)
if p.logger.LogLevel() == int(gocore.DEBUG) {
p.logger.Debugf("[%s] Recv INV (%d items)", p.address, len(invMsg.InvList))
for _, inv := range invMsg.InvList {
p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String())
}
verackMsg := wire.NewMsgVerAck()
if err = wire.WriteMessage(readConn, verackMsg, wire.ProtocolVersion, p.network); err != nil {
p.logger.Errorf("[%s] failed to write message: %v", p.address, err)
}
p.logger.Debugf("[%s] Sent %s", p.address, strings.ToUpper(verackMsg.Command()))
p.sentVerAck.Store(true)

case wire.CmdPing:
pingMsg := msg.(*wire.MsgPing)
p.writeChan <- wire.NewMsgPong(pingMsg.Nonce)

case wire.CmdInv:
invMsg := msg.(*wire.MsgInv)
if p.logger.LogLevel() == int(gocore.DEBUG) {
p.logger.Debugf("[%s] Recv INV (%d items)", p.address, len(invMsg.InvList))
for _, inv := range invMsg.InvList {
p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String())
}
}

go func(invList []*wire.InvVect) {
for _, invVect := range invList {
switch invVect.Type {
case wire.InvTypeTx:
if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil {
p.logger.Errorf("[%s] Unable to process tx %s: %v", p.address, invVect.Hash.String(), err)
}
case wire.InvTypeBlock:
if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, invVect.Hash.String(), err)
}
go func(invList []*wire.InvVect) {
for _, invVect := range invList {
switch invVect.Type {
case wire.InvTypeTx:
if err = p.peerHandler.HandleTransactionAnnouncement(invVect, p); err != nil {
p.logger.Errorf("[%s] Unable to process tx %s: %v", p.address, invVect.Hash.String(), err)
}
case wire.InvTypeBlock:
if err = p.peerHandler.HandleBlockAnnouncement(invVect, p); err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, invVect.Hash.String(), err)
}
}
}(invMsg.InvList)

case wire.CmdGetData:
dataMsg := msg.(*wire.MsgGetData)
p.logger.Infof("[%s] Recv GETDATA (%d items)", p.address, len(dataMsg.InvList))
if p.logger.LogLevel() == int(gocore.DEBUG) {
for _, inv := range dataMsg.InvList {
p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String())
}
}
p.handleGetDataMsg(dataMsg)

case wire.CmdTx:
txMsg := msg.(*wire.MsgTx)
p.logger.Debugf("Recv TX %s (%d bytes)", txMsg.TxHash().String(), txMsg.SerializeSize())
if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil {
p.logger.Errorf("Unable to process tx %s: %v", txMsg.TxHash().String(), err)
}(invMsg.InvList)

case wire.CmdGetData:
dataMsg := msg.(*wire.MsgGetData)
p.logger.Infof("[%s] Recv GETDATA (%d items)", p.address, len(dataMsg.InvList))
if p.logger.LogLevel() == int(gocore.DEBUG) {
for _, inv := range dataMsg.InvList {
p.logger.Debugf(" [%s] %s", p.address, inv.Hash.String())
}
}
p.handleGetDataMsg(dataMsg)

case wire.CmdBlock:
msgBlock, ok := msg.(*wire.MsgBlock)
if ok {
p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), msgBlock.Header.BlockHash().String())
case wire.CmdTx:
txMsg := msg.(*wire.MsgTx)
p.logger.Debugf("Recv TX %s (%d bytes)", txMsg.TxHash().String(), txMsg.SerializeSize())
if err = p.peerHandler.HandleTransaction(txMsg, p); err != nil {
p.logger.Errorf("Unable to process tx %s: %v", txMsg.TxHash().String(), err)
}

err = p.peerHandler.HandleBlock(msgBlock, p)
if err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, msgBlock.Header.BlockHash().String(), err)
}
continue
}
case wire.CmdBlock:
msgBlock, ok := msg.(*wire.MsgBlock)
if ok {
p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), msgBlock.Header.BlockHash().String())

// Please note that this is the BlockMessage, not the wire.MsgBlock
blockMsg, ok := msg.(*BlockMessage)
if !ok {
p.logger.Errorf("Unable to cast block message, calling with generic wire.Message")
err = p.peerHandler.HandleBlock(msg, p)
if err != nil {
p.logger.Errorf("[%s] Unable to process block message: %v", p.address, err)
}
continue
err = p.peerHandler.HandleBlock(msgBlock, p)
if err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, msgBlock.Header.BlockHash().String(), err)
}
continue
}

p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), blockMsg.Header.BlockHash().String())

err = p.peerHandler.HandleBlock(blockMsg, p)
// Please note that this is the BlockMessage, not the wire.MsgBlock
blockMsg, ok := msg.(*BlockMessage)
if !ok {
p.logger.Errorf("Unable to cast block message, calling with generic wire.Message")
err = p.peerHandler.HandleBlock(msg, p)
if err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, blockMsg.Header.BlockHash().String(), err)
p.logger.Errorf("[%s] Unable to process block message: %v", p.address, err)
}
continue
}

case wire.CmdReject:
rejMsg := msg.(*wire.MsgReject)
if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, rejMsg.Hash.String(), err)
}
p.logger.Infof("[%s] Recv %s: %s", p.address, strings.ToUpper(msg.Command()), blockMsg.Header.BlockHash().String())

case wire.CmdVerAck:
p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command()))
p.receivedVerAck.Store(true)
err = p.peerHandler.HandleBlock(blockMsg, p)
if err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, blockMsg.Header.BlockHash().String(), err)
}

default:
p.logger.Debugf("[%s] Ignored %s", p.address, strings.ToUpper(msg.Command()))
case wire.CmdReject:
rejMsg := msg.(*wire.MsgReject)
if err = p.peerHandler.HandleTransactionRejection(rejMsg, p); err != nil {
p.logger.Errorf("[%s] Unable to process block %s: %v", p.address, rejMsg.Hash.String(), err)
}

case wire.CmdVerAck:
p.logger.Debugf("[%s] Recv %s", p.address, strings.ToUpper(msg.Command()))
p.receivedVerAck.Store(true)

default:
p.logger.Debugf("[%s] Ignored %s", p.address, strings.ToUpper(msg.Command()))
}
}

}

func (p *Peer) handleGetDataMsg(dataMsg *wire.MsgGetData) {
Expand Down
6 changes: 6 additions & 0 deletions Peer_Options.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,9 @@ func WithIncomingConnection(conn net.Conn) PeerOptions {
p.incomingConn = conn
}
}

func WithMaximumMessageSize(maximumMessageSize int64) PeerOptions {
return func(p *Peer) {
p.maximumMessageSize = maximumMessageSize
}
}

0 comments on commit 822970a

Please sign in to comment.