diff --git a/peer_manager.go b/peer_manager.go index a54dfed..cf634eb 100644 --- a/peer_manager.go +++ b/peer_manager.go @@ -3,6 +3,7 @@ package p2p import ( "context" "log/slog" + "math/rand" "sort" "sync" "time" @@ -128,23 +129,17 @@ func (pm *PeerManager) AnnounceTransaction(txHash *chainhash.Hash, peers []PeerI } func (pm *PeerManager) RequestTransaction(txHash *chainhash.Hash) PeerI { - // send to the first found peer that is connected - var sendToPeer PeerI - for _, peer := range pm.GetAnnouncedPeers() { - if peer.Connected() { - sendToPeer = peer - break - } - } + pm.mu.RLock() + defer pm.mu.RUnlock() - // we don't have any connected peers - if sendToPeer == nil { + peer := pm.GetRandomConnectedPeer() + if peer == nil { return nil } - sendToPeer.RequestTransaction(txHash) + peer.RequestTransaction(txHash) - return sendToPeer + return peer } func (pm *PeerManager) AnnounceBlock(blockHash *chainhash.Hash, peers []PeerI) []PeerI { @@ -179,6 +174,24 @@ func (pm *PeerManager) RequestBlock(blockHash *chainhash.Hash) PeerI { return sendToPeer } +func (pm *PeerManager) GetRandomConnectedPeer() PeerI { + pm.mu.RLock() + defer pm.mu.RUnlock() + + connectedPeers := make([]PeerI, 0, len(pm.peers)) + for _, peer := range pm.peers { + if peer.Connected() { + connectedPeers = append(connectedPeers, peer) + } + } + + if len(connectedPeers) == 0 { + return nil + } + + return connectedPeers[rand.Intn(len(connectedPeers))] +} + func (pm *PeerManager) GetAnnouncedPeers() []PeerI { pm.mu.RLock() defer pm.mu.RUnlock()