diff --git a/peer.go b/peer.go index 4b50e97..d2edf4e 100644 --- a/peer.go +++ b/peer.go @@ -79,7 +79,7 @@ type Peer struct { cancelReadHandler context.CancelFunc cancelWriteHandler context.CancelFunc - cancelAll context.CancelFunc + cancelMonitoring context.CancelFunc readerWg *sync.WaitGroup writerWg *sync.WaitGroup @@ -125,10 +125,8 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw } } - ctx := context.Background() - - cancelCtx, cancelAll := context.WithCancel(ctx) - p.cancelAll = cancelAll + cancelCtx, cancelAll := context.WithCancel(context.Background()) + p.cancelMonitoring = cancelAll p.ctx = cancelCtx p.healthMonitorWg.Add(1) @@ -141,19 +139,19 @@ func NewPeer(logger *slog.Logger, address string, peerHandler PeerHandlerI, netw p.dataBatcher = batcher.New(500, p.batchDelay, p.sendDataBatch, true) if p.incomingConn != nil { - go func(funcCtx context.Context) { - connectErr := p.connectAndStartReadWriteHandlers(funcCtx) + go func() { + connectErr := p.connectAndStartReadWriteHandlers() if connectErr != nil { p.logger.Warn("Failed to connect to peer", slog.String(errKey, err.Error())) } - }(ctx) + }() p.logger.Info("Incoming connection from peer") return p, nil } // reconnect if disconnected, but only on outgoing connections p.reconnectingWg.Add(1) - go p.reconnect(ctx) + go p.reconnect() return p, nil } @@ -164,11 +162,11 @@ func (p *Peer) disconnectLock() { p.disconnect() } -func (p *Peer) reconnect(ctx context.Context) { +func (p *Peer) reconnect() { defer func() { p.reconnectingWg.Done() }() - connectErr := p.connectAndStartReadWriteHandlers(ctx) + connectErr := p.connectAndStartReadWriteHandlers() if connectErr != nil { p.logger.Warn("Failed to connect to peer", slog.String(errKey, connectErr.Error())) } @@ -181,7 +179,7 @@ func (p *Peer) reconnect(ctx context.Context) { } p.logger.Info("Reconnecting") - connectErr = p.connectAndStartReadWriteHandlers(ctx) + connectErr = p.connectAndStartReadWriteHandlers() if connectErr != nil { p.logger.Warn("Failed to connect to peer", slog.String(errKey, connectErr.Error())) continue @@ -203,7 +201,7 @@ func (p *Peer) disconnect() { p.receivedVerAck.Store(false) } -func (p *Peer) connectAndStartReadWriteHandlers(ctx context.Context) error { +func (p *Peer) connectAndStartReadWriteHandlers() error { p.mu.Lock() defer p.mu.Unlock() @@ -230,6 +228,8 @@ func (p *Peer) connectAndStartReadWriteHandlers(ctx context.Context) error { p.readConn = conn } + ctx := context.Background() + writerCtx, cancelWriter := context.WithCancel(ctx) p.cancelWriteHandler = cancelWriter for i := 0; i < 10; i++ { @@ -847,7 +847,7 @@ func (p *Peer) stopWriteHandler() { } func (p *Peer) Shutdown() { - p.cancelAll() + p.cancelMonitoring() p.reconnectingWg.Wait() p.healthMonitorWg.Wait() p.pingHandlerWg.Wait()