diff --git a/docs/release-notes/release-notes-0.19.0.md b/docs/release-notes/release-notes-0.19.0.md index 0b576fc7de..dbf8dc31c9 100644 --- a/docs/release-notes/release-notes-0.19.0.md +++ b/docs/release-notes/release-notes-0.19.0.md @@ -63,6 +63,9 @@ * [Fixed a bug](https://github.com/lightningnetwork/lnd/pull/9322) that caused estimateroutefee to ignore the default payment timeout. +* [Fixed a possible crash of htlcswitch upon shutdown](https://github.com/lightningnetwork/lnd/pull/9140) + caused by a race condition in goroutines tracking mechanism. + # New Features * [Support](https://github.com/lightningnetwork/lnd/pull/8390) for diff --git a/go.mod b/go.mod index 36766d0d81..e2495594c4 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn/v2 v2.0.2 + github.com/lightningnetwork/lnd/fn/v2 v2.0.4 github.com/lightningnetwork/lnd/healthcheck v1.2.6 github.com/lightningnetwork/lnd/kvdb v1.4.11 github.com/lightningnetwork/lnd/queue v1.1.1 diff --git a/go.sum b/go.sum index 69c57c20e8..8d38c03b0a 100644 --- a/go.sum +++ b/go.sum @@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn/v2 v2.0.2 h1:M7o2lYrh/zCp+lntPB3WP/rWTu5U+4ssyHW+kqNJ0fs= -github.com/lightningnetwork/lnd/fn/v2 v2.0.2/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= +github.com/lightningnetwork/lnd/fn/v2 v2.0.4 h1:DiC/AEa7DhnY4qOEQBISu1cp+1+51LjbVDzNLVBwNjI= +github.com/lightningnetwork/lnd/fn/v2 v2.0.4/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI= github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ= github.com/lightningnetwork/lnd/kvdb v1.4.11 h1:fk1HMVFrsVK3xqU7q+JWHRgBltw/a2qIg1E3zazMb/8= diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 4c54fab0a5..1e96c586ce 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2,6 +2,7 @@ package htlcswitch import ( "bytes" + "context" "errors" "fmt" "math/rand" @@ -85,6 +86,9 @@ var ( // fail payments if they increase our fee exposure. This is currently // set to 500m msats. DefaultMaxFeeExposure = lnwire.MilliSatoshi(500_000_000) + + // background is a shortcut for context.Background. + background = context.Background() ) // plexPacket encapsulates switch packet and adds error channel to receive @@ -245,8 +249,8 @@ type Switch struct { // This will be retrieved by the registered links atomically. bestHeight uint32 - wg sync.WaitGroup - quit chan struct{} + // gm starts and stops tasks in goroutines and waits for them. + gm *fn.GoroutineManager // cfg is a copy of the configuration struct that the htlc switch // service was initialized with. @@ -368,8 +372,11 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { return nil, err } + gm := fn.NewGoroutineManager() + s := &Switch{ bestHeight: currentHeight, + gm: gm, cfg: &cfg, circuits: circuitMap, linkIndex: make(map[lnwire.ChannelID]ChannelLink), @@ -382,7 +389,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) { chanCloseRequests: make(chan *ChanClose), resolutionMsgs: make(chan *resolutionMsg), resMsgStore: resStore, - quit: make(chan struct{}), } s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) @@ -420,14 +426,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro ResolutionMsg: msg, errChan: errChan, }: - case <-s.quit: + case <-s.gm.Done(): return ErrSwitchExiting } select { case err := <-errChan: return err - case <-s.quit: + case <-s.gm.Done(): return ErrSwitchExiting } } @@ -493,14 +499,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, // Since the attempt was known, we can start a goroutine that can // extract the result when it is available, and pass it on to the // caller. - s.wg.Add(1) - go func() { - defer s.wg.Done() - + ok := s.gm.Go(background, func(ctx context.Context) { var n *networkResult select { case n = <-nChan: - case <-s.quit: + case <-s.gm.Done(): // We close the result channel to signal a shutdown. We // don't send any result in this case since the HTLC is // still in flight. @@ -524,7 +527,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash, return } resultChan <- result - }() + }) + // The switch shutting down is signaled by closing the channel. + if !ok { + close(resultChan) + } return resultChan, nil } @@ -704,12 +711,19 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, select { case <-linkQuit: return nil - case <-s.quit: + + case <-s.gm.Done(): return nil + default: - // Spawn a goroutine to log the errors returned from failed packets. - s.wg.Add(1) - go s.logFwdErrs(&numSent, &wg, fwdChan) + // Spawn a goroutine to log the errors returned from failed + // packets. + ok := s.gm.Go(background, func(ctx context.Context) { + s.logFwdErrs(ctx, &numSent, &wg, fwdChan) + }) + if !ok { + return nil + } } // Make a first pass over the packets, forwarding any settles or fails. @@ -820,8 +834,8 @@ func (s *Switch) ForwardPackets(linkQuit <-chan struct{}, } // logFwdErrs logs any errors received on `fwdChan`. -func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { - defer s.wg.Done() +func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup, + fwdChan chan error) { // Wait here until the outer function has finished persisting // and routing the packets. This guarantees we don't read from num until @@ -836,7 +850,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) { log.Errorf("Unhandled error while reforwarding htlc "+ "settle/fail over htlcswitch: %v", err) } - case <-s.quit: + + case <-s.gm.Done(): log.Errorf("unable to forward htlc packet " + "htlc switch was stopped") return @@ -862,7 +877,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error, return nil case <-linkQuit: return ErrLinkShuttingDown - case <-s.quit: + case <-s.gm.Done(): return errors.New("htlc switch was stopped") } } @@ -940,8 +955,6 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) ( // // NOTE: This method MUST be spawned as a goroutine. func (s *Switch) handleLocalResponse(pkt *htlcPacket) { - defer s.wg.Done() - attemptID := pkt.incomingHTLCID // The error reason will be unencypted in case this a local @@ -1436,7 +1449,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, case s.chanCloseRequests <- command: return updateChan, errChan - case <-s.quit: + case <-s.gm.Done(): errChan <- ErrSwitchExiting close(updateChan) return updateChan, errChan @@ -1454,8 +1467,6 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint, // // NOTE: This MUST be run as a goroutine. func (s *Switch) htlcForwarder() { - defer s.wg.Done() - defer func() { s.blockEpochStream.Cancel() @@ -1489,6 +1500,8 @@ func (s *Switch) htlcForwarder() { var wg sync.WaitGroup for _, link := range linksToStop { wg.Add(1) + // Here it is ok to start a goroutine directly bypassing + // s.gm, because we want for them to complete here. go func(l ChannelLink) { defer wg.Done() @@ -1628,15 +1641,16 @@ out: // collect all the forwarding events since the last internal, // and write them out to our log. case <-s.cfg.FwdEventTicker.Ticks(): - s.wg.Add(1) - go func() { - defer s.wg.Done() - - if err := s.FlushForwardingEvents(); err != nil { + // The error of Go is ignored: if it is shutting down, + // the loop will terminate on the next iteration, in + // s.gm.Done case. + _ = s.gm.Go(background, func(ctx context.Context) { + err := s.FlushForwardingEvents() + if err != nil { log.Errorf("Unable to flush "+ "forwarding events: %v", err) } - }() + }) // The log ticker has fired, so we'll calculate some forwarding // stats for the last 10 seconds to display within the logs to @@ -1739,7 +1753,7 @@ out: // memory. s.pendingSettleFails = s.pendingSettleFails[:0] - case <-s.quit: + case <-s.gm.Done(): return } } @@ -1749,6 +1763,7 @@ out: func (s *Switch) Start() error { if !atomic.CompareAndSwapInt32(&s.started, 0, 1) { log.Warn("Htlc Switch already started") + return errors.New("htlc switch already started") } @@ -1760,12 +1775,24 @@ func (s *Switch) Start() error { } s.blockEpochStream = blockEpochStream - s.wg.Add(1) - go s.htlcForwarder() + ok := s.gm.Go(background, func(ctx context.Context) { + s.htlcForwarder() + }) + if !ok { + // We are already stopping so we can ignore the error. + _ = s.Stop() + err = fmt.Errorf("unable to start htlc forwarder: %w", + ErrSwitchExiting) + log.Errorf("%v", err) + + return err + } if err := s.reforwardResponses(); err != nil { - s.Stop() + // We are already stopping so we can ignore the error. + _ = s.Stop() log.Errorf("unable to reforward responses: %v", err) + return err } @@ -1773,6 +1800,7 @@ func (s *Switch) Start() error { // We are already stopping so we can ignore the error. _ = s.Stop() log.Errorf("unable to reforward resolutions: %v", err) + return err } @@ -1991,9 +2019,8 @@ func (s *Switch) Stop() error { log.Info("HTLC Switch shutting down...") defer log.Debug("HTLC Switch shutdown complete") - close(s.quit) - - s.wg.Wait() + // Ask running goroutines to stop and wait for them. + s.gm.Stop() // Wait until all active goroutines have finished exiting before // stopping the mailboxes, otherwise the mailbox map could still be @@ -2349,7 +2376,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) { select { case <-stopChan: return - case <-s.quit: + case <-s.gm.Done(): return } } @@ -3020,8 +3047,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { // NOTE: `closeCircuit` modifies the state of `packet`. if localHTLC { // TODO(yy): remove the goroutine and send back the error here. - s.wg.Add(1) - go s.handleLocalResponse(packet) + ok := s.gm.Go(background, func(ctx context.Context) { + s.handleLocalResponse(packet) + }) + if !ok { + return ErrSwitchExiting + } // If this is a locally initiated HTLC, there's no need to // forward it so we exit. @@ -3076,8 +3107,12 @@ func (s *Switch) handlePacketFail(packet *htlcPacket, // NOTE: `closeCircuit` modifies the state of `packet`. if packet.incomingChanID == hop.Source { // TODO(yy): remove the goroutine and send back the error here. - s.wg.Add(1) - go s.handleLocalResponse(packet) + ok := s.gm.Go(background, func(ctx context.Context) { + s.handleLocalResponse(packet) + }) + if !ok { + return ErrSwitchExiting + } // If this is a locally initiated HTLC, there's no need to // forward it so we exit. diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 8809321460..fbd9bf4bfa 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -8,6 +8,7 @@ import ( "io" mrand "math/rand" "reflect" + "sync" "testing" "time" @@ -3159,6 +3160,60 @@ func TestSwitchGetAttemptResult(t *testing.T) { } } +// TestSwitchGetAttemptResultStress runs series of GetAttemptResult and Stop in +// parallel to make sure there is no race condition between these actions. +func TestSwitchGetAttemptResultStress(t *testing.T) { + t.Parallel() + + const paymentID = 123 + + s, err := initSwitchWithTempDB(t, testStartingHeight) + require.NoError(t, err, "unable to init switch") + require.NoError(t, s.Start(), "unable to start switch") + + lookup := make(chan *PaymentCircuit, 1) + s.circuits = &mockCircuitMap{ + lookup: lookup, + } + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + + for range 1000 { + // Next let the lookup find the circuit in the circuit + // map. It should subscribe to payment results, and + // return the result when available. + lookup <- &PaymentCircuit{} + _, err := s.GetAttemptResult( + paymentID, lntypes.Hash{}, + newMockDeobfuscator(), + ) + require.NoError(t, err, "unable to get payment result") + } + }() + + // Run s.Stop() in parallel with consecutive GetAttemptResult calls to + // make sure this doesn't result in a race condition. + wg.Add(1) + go func() { + defer wg.Done() + + // Sleep 10ms to let several GetAttemptResult calls happen, so + // s.Stop() happens in the middle of GetAttemptResult series. + // The value 10ms was found empirically - this time is needed + // to expose the race condition (as a crash under -race) in the + // version of Switch before GoroutineManager was added. + time.Sleep(10 * time.Millisecond) + + require.NoError(t, s.Stop()) + }() + + wg.Wait() +} + // TestInvalidFailure tests that the switch returns an unreadable failure error // if the failure cannot be decrypted. func TestInvalidFailure(t *testing.T) { @@ -4953,7 +5008,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { // Pull packet from Bob's link, and do nothing with it. select { case <-bobLink.packets: - case <-s.quit: + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to forward packet") } @@ -5012,7 +5067,8 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) - case <-s2.quit: + + case <-s2.gm.Done(): t.Fatal("switch shutting down, failed to forward packet") } } @@ -5193,7 +5249,8 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) - case <-s.quit: + + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to receive fail packet") } } @@ -5393,7 +5450,8 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) - case <-s.quit: + + case <-s.gm.Done(): t.Fatal("switch shutting down, failed to receive failure") } } @@ -5549,7 +5607,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { isAlias := failScid == aliceAlias || failScid == aliceAlias2 require.True(t, isAlias) - case <-s.quit: + case <-s.gm.Done(): t.Fatalf("switch shutting down, failed to receive failure") } diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index 2cc1219022..5a7d86b792 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -25,6 +25,9 @@ var ( // ErrStateMachineShutdown occurs when trying to feed an event to a // StateMachine that has been asked to Stop. ErrStateMachineShutdown = fmt.Errorf("StateMachine is shutting down") + + // background is a shortcut for context.Background. + background = context.Background() ) // EmittedEvent is a special type that can be emitted by a state transition. @@ -200,7 +203,7 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env] cfg: cfg, events: make(chan Event, 1), stateQuery: make(chan stateQuery[Event, Env]), - wg: *fn.NewGoroutineManager(context.Background()), + wg: *fn.NewGoroutineManager(), newStateEvents: fn.NewEventDistributor[State[Event, Env]](), quit: make(chan struct{}), } @@ -210,7 +213,7 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env] // the state machine to completion. func (s *StateMachine[Event, Env]) Start() { s.startOnce.Do(func() { - _ = s.wg.Go(func(ctx context.Context) { + _ = s.wg.Go(background, func(ctx context.Context) { s.driveMachine() }) }) @@ -355,15 +358,19 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // If a post-send event was specified, then we'll funnel // that back into the main state machine now as well. return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll - launched := s.wg.Go(func(ctx context.Context) { - log.Debugf("FSM(%v): sending "+ - "post-send event: %v", - s.cfg.Env.Name(), - lnutils.SpewLogClosure(event), - ) - - s.SendEvent(event) - }) + launched := s.wg.Go( + background, func(ctx context.Context) { + log.Debugf("FSM(%v): sending "+ + "post-send event: %v", + s.cfg.Env.Name(), + lnutils.SpewLogClosure( + event, + ), + ) + + s.SendEvent(event) + }, + ) if !launched { return ErrStateMachineShutdown @@ -382,7 +389,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // Otherwise, this has a SendWhen predicate, so we'll need // launch a goroutine to poll the SendWhen, then send only once // the predicate is true. - launched := s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(background, func(ctx context.Context) { predicateTicker := time.NewTicker( s.cfg.CustomPollInterval.UnwrapOr(pollInterval), ) @@ -456,7 +463,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register spend: %w", err) } - launched := s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(background, func(ctx context.Context) { for { select { case spend, ok := <-spendEvent.Spend: @@ -502,7 +509,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register conf: %w", err) } - launched := s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(background, func(ctx context.Context) { for { select { case <-confEvent.Confirmed: