diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 4ff2746a73..5f5d9eac70 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2,6 +2,7 @@ package htlcswitch import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -245,8 +246,14 @@ type Switch struct { // This will be retrieved by the registered links atomically. bestHeight uint32 - wg sync.WaitGroup - quit chan struct{} + // TODO(yy): remove handleLocalResponseWG, once handleLocalResponse runs + // without a goroutine. Currently we can't run handleLocalResponse in + // gm, since if gm is stopping, the goroutine won't start and it is + // unclear if it safe to skip handleLocalResponse. + handleLocalResponseWG sync.WaitGroup + + // gm starts and stops tasks in goroutines and waits for them. + gm *fn.GoroutineManager // cfg is a copy of the configuration struct that the htlc switch // service was initialized with. @@ -370,6 +377,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { s := &Switch{ bestHeight: currentHeight, + gm: fn.NewGoroutineManager(context.Background()), cfg: &cfg, circuits: circuitMap, linkIndex: make(map[lnwire.ChannelID]ChannelLink), @@ -382,7 +390,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { chanCloseRequests: make(chan *ChanClose), resolutionMsgs: make(chan *resolutionMsg), resMsgStore: resStore, - quit: make(chan struct{}), } s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) @@ -420,14 +427,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro ResolutionMsg: msg, errChan: errChan, }: - case <-s.quit: + case <-s.gm.Done(): return ErrSwitchExiting } select { case err := <-errChan: return err - case <-s.quit: + case <-s.gm.Done(): return ErrSwitchExiting } } @@ -493,14 +500,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, // Since the attempt was known, we can start a goroutine that can // extract the result when it is available, and pass it on to the // caller. - s.wg.Add(1) - go func() { - defer s.wg.Done() - + err = s.gm.Go(func(ctx context.Context) { var n *networkResult select { case n = <-nChan: - case <-s.quit: + case <-ctx.Done(): // We close the result channel to signal a shutdown. We // don't send any result in this case since the HTLC is // still in flight. @@ -524,7 +528,12 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, return } resultChan <- result - }() + }) + + // The switch shutting down is signaled by closing the channel. + if err != nil { + close(resultChan) + } return resultChan, nil } @@ -704,12 +713,18 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{}, select { case <-linkQuit: return nil - case <-s.quit: + + case <-s.gm.Done(): return nil + default: // Spawn a goroutine to log the errors returned from failed packets. - s.wg.Add(1) - go s.logFwdErrs(&numSent, &wg, fwdChan) + err := s.gm.Go(func(ctx context.Context) { + s.logFwdErrs(ctx, &numSent, &wg, fwdChan) + }) + if err != nil { + return nil + } } // Make a first pass over the packets, forwarding any settles or fails. @@ -820,8 +835,8 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{}, } // logFwdErrs logs any errors received on `fwdChan`. -func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { - defer s.wg.Done() +func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup, + fwdChan chan error) { // Wait here until the outer function has finished persisting // and routing the packets. This guarantees we don't read from num until @@ -836,7 +851,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { log.Errorf("Unhandled error while reforwarding htlc "+ "settle/fail over htlcswitch: %v", err) } - case <-s.quit: + + case <-ctx.Done(): log.Errorf("unable to forward htlc packet " + "htlc switch was stopped") return @@ -862,7 +878,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, return nil case <-linkQuit: return ErrLinkShuttingDown - case <-s.quit: + case <-s.gm.Done(): return errors.New("htlc switch was stopped") } } @@ -939,7 +955,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( // // NOTE: This method MUST be spawned as a goroutine. func (s *Switch) handleLocalResponse(pkt *htlcPacket) { - defer s.wg.Done() + defer s.handleLocalResponseWG.Done() attemptID := pkt.incomingHTLCID @@ -1435,7 +1451,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, case s.chanCloseRequests <- command: return updateChan, errChan - case <-s.quit: + case <-s.gm.Done(): errChan <- ErrSwitchExiting close(updateChan) return updateChan, errChan @@ -1453,8 +1469,6 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, // // NOTE: This MUST be run as a goroutine. func (s *Switch) htlcForwarder() { - defer s.wg.Done() - defer func() { s.blockEpochStream.Cancel() @@ -1488,6 +1502,8 @@ func (s *Switch) htlcForwarder() { var wg sync.WaitGroup for _, link := range linksToStop { wg.Add(1) + // Here it is ok to start a goroutine directly bypassing + // s.gm, because we want for them to complete here. go func(l ChannelLink) { defer wg.Done() @@ -1627,15 +1643,16 @@ out: // collect all the forwarding events since the last internal, // and write them out to our log. case <-s.cfg.FwdEventTicker.Ticks(): - s.wg.Add(1) - go func() { - defer s.wg.Done() - - if err := s.FlushForwardingEvents(); err != nil { + err := s.gm.Go(func(ctx context.Context) { + err := s.FlushForwardingEvents() + if err != nil { log.Errorf("unable to flush "+ "forwarding events: %v", err) } - }() + }) + if err != nil { + return + } // The log ticker has fired, so we'll calculate some forwarding // stats for the last 10 seconds to display within the logs to @@ -1738,7 +1755,7 @@ out: // memory. s.pendingSettleFails = s.pendingSettleFails[:0] - case <-s.quit: + case <-s.gm.Done(): return } } @@ -1759,8 +1776,15 @@ func (s *Switch) Start() error { } s.blockEpochStream = blockEpochStream - s.wg.Add(1) - go s.htlcForwarder() + err = s.gm.Go(func(ctx context.Context) { + s.htlcForwarder() + }) + if err != nil { + s.Stop() + err = fmt.Errorf("unable to start htlc forwarder: %w", err) + log.Errorf("%v", err) + return err + } if err := s.reforwardResponses(); err != nil { s.Stop() @@ -1990,9 +2014,11 @@ func (s *Switch) Stop() error { log.Info("HTLC Switch shutting down...") defer log.Debug("HTLC Switch shutdown complete") - close(s.quit) + // Ask running goroutines to stop and wait for them. + s.gm.Stop() - s.wg.Wait() + // TODO(yy): remove this, when s.handleLocalResponseWG is removed. + s.handleLocalResponseWG.Wait() // Wait until all active goroutines have finished exiting before // stopping the mailboxes, otherwise the mailbox map could still be @@ -2348,7 +2374,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) { select { case <-stopChan: return - case <-s.quit: + case <-s.gm.Done(): return } } @@ -3020,7 +3046,7 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { // NOTE: `closeCircuit` modifies the state of `packet`. if localHTLC { // TODO(yy): remove the goroutine and send back the error here. - s.wg.Add(1) + s.handleLocalResponseWG.Add(1) go s.handleLocalResponse(packet) // If this is a locally initiated HTLC, there's no need to @@ -3076,7 +3102,7 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, // NOTE: `closeCircuit` modifies the state of `packet`. if packet.incomingChanID == hop.Source { // TODO(yy): remove the goroutine and send back the error here. - s.wg.Add(1) + s.handleLocalResponseWG.Add(1) go s.handleLocalResponse(packet) // If this is a locally initiated HTLC, there's no need to diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 825ee6c652..11d0bbd3cc 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -8,6 +8,7 @@ import ( "io" mrand "math/rand" "reflect" + "sync" "testing" "time" @@ -3180,6 +3181,54 @@ func TestSwitchGetAttemptResult(t *testing.T) { } } +func TestSwitchGetAttemptResultStress(t *testing.T) { + t.Parallel() + + const paymentID = 123 + var preimg lntypes.Preimage + preimg[0] = 3 + + s, err := initSwitchWithTempDB(t, testStartingHeight) + require.NoError(t, err, "unable to init switch") + if err := s.Start(); err != nil { + t.Fatalf("unable to start switch: %v", err) + } + defer s.Stop() + + lookup := make(chan *PaymentCircuit, 1) + s.circuits = &mockCircuitMap{ + lookup: lookup, + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for range 1000 { + // Next let the lookup find the circuit in the circuit map. It should + // subscribe to payment results, and return the result when available. + lookup <- &PaymentCircuit{} + _, err := s.GetAttemptResult( + paymentID, lntypes.Hash{}, newMockDeobfuscator(), + ) + require.NoError(t, err, "unable to get payment result") + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + + time.Sleep(10 * time.Millisecond) + + s.Stop() + }() + + wg.Wait() +} + // TestInvalidFailure tests that the switch returns an unreadable failure error // if the failure cannot be decrypted. func TestInvalidFailure(t *testing.T) { @@ -4985,7 +5034,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { // Pull packet from Bob's link, and do nothing with it. select { case <-bobLink.packets: - case <-s.quit: + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to forward packet") } @@ -5046,7 +5095,8 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) - case <-s2.quit: + + case <-s2.gm.Done(): t.Fatal("switch shutting down, failed to forward packet") } } @@ -5229,7 +5279,8 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) - case <-s.quit: + + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to receive fail packet") } } @@ -5429,7 +5480,8 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) - case <-s.quit: + + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to receive failure") } } @@ -5587,7 +5639,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { isAlias := failScid == aliceAlias || failScid == aliceAlias2 require.True(t, isAlias) - case <-s.quit: + case <-s.gm.Done(): t.Fatalf("switch shutting down, failed to receive failure") }