Skip to content

Commit

Permalink
htlcswitch: use fn.GoroutineManager
Browse files Browse the repository at this point in the history
Replaced the use of s.quit and s.wg with s.gm (GoroutineManager). WaitGroup is
still needed to wait for handleLocalResponse: if it was switched to s.gm, then
it may skip running, which has unclear consequences. After handleLocalResponse
is changed to run without a goroutine, we can remove WaitGroup completely.

This fixes a race condition between s.wg.Add(1) and s.wg.Wait().
Also added a test which used to fail under `-race` before this commit.
  • Loading branch information
starius committed Oct 11, 2024
1 parent 124f087 commit 39b7a3b
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 41 deletions.
98 changes: 62 additions & 36 deletions htlcswitch/switch.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package htlcswitch

import (
"bytes"
"context"
"errors"
"fmt"
"math/rand"
Expand Down Expand Up @@ -245,8 +246,14 @@ type Switch struct {
// This will be retrieved by the registered links atomically.
bestHeight uint32

wg sync.WaitGroup
quit chan struct{}
// TODO(yy): remove handleLocalResponseWG, once handleLocalResponse runs
// without a goroutine. Currently we can't run handleLocalResponse in
// gm, since if gm is stopping, the goroutine won't start and it is
// unclear if it safe to skip handleLocalResponse.
handleLocalResponseWG sync.WaitGroup

// gm starts and stops tasks in goroutines and waits for them.
gm *fn.GoroutineManager

// cfg is a copy of the configuration struct that the htlc switch
// service was initialized with.
Expand Down Expand Up @@ -370,6 +377,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {

s := &Switch{
bestHeight: currentHeight,
gm: fn.NewGoroutineManager(context.Background()),
cfg: &cfg,
circuits: circuitMap,
linkIndex: make(map[lnwire.ChannelID]ChannelLink),
Expand All @@ -382,7 +390,6 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
chanCloseRequests: make(chan *ChanClose),
resolutionMsgs: make(chan *resolutionMsg),
resMsgStore: resStore,
quit: make(chan struct{}),
}

s.aliasToReal = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
Expand Down Expand Up @@ -420,14 +427,14 @@ func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) erro
ResolutionMsg: msg,
errChan: errChan,
}:
case <-s.quit:
case <-s.gm.Done():
return ErrSwitchExiting
}

select {
case err := <-errChan:
return err
case <-s.quit:
case <-s.gm.Done():
return ErrSwitchExiting
}
}
Expand Down Expand Up @@ -493,14 +500,11 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash,
// Since the attempt was known, we can start a goroutine that can
// extract the result when it is available, and pass it on to the
// caller.
s.wg.Add(1)
go func() {
defer s.wg.Done()

err = s.gm.Go(func(ctx context.Context) {
var n *networkResult
select {
case n = <-nChan:
case <-s.quit:
case <-ctx.Done():
// We close the result channel to signal a shutdown. We
// don't send any result in this case since the HTLC is
// still in flight.
Expand All @@ -524,7 +528,12 @@ func (s *Switch) GetAttemptResult(attemptID uint64, paymentHash lntypes.Hash,
return
}
resultChan <- result
}()
})

// The switch shutting down is signaled by closing the channel.
if err != nil {
close(resultChan)
}

return resultChan, nil
}
Expand Down Expand Up @@ -704,12 +713,18 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{},
select {
case <-linkQuit:
return nil
case <-s.quit:

case <-s.gm.Done():
return nil

default:
// Spawn a goroutine to log the errors returned from failed packets.
s.wg.Add(1)
go s.logFwdErrs(&numSent, &wg, fwdChan)
err := s.gm.Go(func(ctx context.Context) {
s.logFwdErrs(ctx, &numSent, &wg, fwdChan)
})
if err != nil {
return nil
}
}

// Make a first pass over the packets, forwarding any settles or fails.
Expand Down Expand Up @@ -820,8 +835,8 @@ func (s *Switch) ForwardPackets(linkQuit chan struct{},
}

// logFwdErrs logs any errors received on `fwdChan`.
func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) {
defer s.wg.Done()
func (s *Switch) logFwdErrs(ctx context.Context, num *int, wg *sync.WaitGroup,
fwdChan chan error) {

// Wait here until the outer function has finished persisting
// and routing the packets. This guarantees we don't read from num until
Expand All @@ -836,7 +851,8 @@ func (s *Switch) logFwdErrs(num *int, wg *sync.WaitGroup, fwdChan chan error) {
log.Errorf("Unhandled error while reforwarding htlc "+
"settle/fail over htlcswitch: %v", err)
}
case <-s.quit:

case <-ctx.Done():
log.Errorf("unable to forward htlc packet " +
"htlc switch was stopped")
return
Expand All @@ -862,7 +878,7 @@ func (s *Switch) routeAsync(packet *htlcPacket, errChan chan error,
return nil
case <-linkQuit:
return ErrLinkShuttingDown
case <-s.quit:
case <-s.gm.Done():
return errors.New("htlc switch was stopped")
}
}
Expand Down Expand Up @@ -939,7 +955,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) (
//
// NOTE: This method MUST be spawned as a goroutine.
func (s *Switch) handleLocalResponse(pkt *htlcPacket) {
defer s.wg.Done()
defer s.handleLocalResponseWG.Done()

attemptID := pkt.incomingHTLCID

Expand Down Expand Up @@ -1435,7 +1451,7 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
case s.chanCloseRequests <- command:
return updateChan, errChan

case <-s.quit:
case <-s.gm.Done():
errChan <- ErrSwitchExiting
close(updateChan)
return updateChan, errChan
Expand All @@ -1453,8 +1469,6 @@ func (s *Switch) CloseLink(chanPoint *wire.OutPoint,
//
// NOTE: This MUST be run as a goroutine.
func (s *Switch) htlcForwarder() {
defer s.wg.Done()

defer func() {
s.blockEpochStream.Cancel()

Expand Down Expand Up @@ -1488,6 +1502,8 @@ func (s *Switch) htlcForwarder() {
var wg sync.WaitGroup
for _, link := range linksToStop {
wg.Add(1)
// Here it is ok to start a goroutine directly bypassing
// s.gm, because we want for them to complete here.
go func(l ChannelLink) {
defer wg.Done()

Expand Down Expand Up @@ -1627,15 +1643,16 @@ out:
// collect all the forwarding events since the last internal,
// and write them out to our log.
case <-s.cfg.FwdEventTicker.Ticks():
s.wg.Add(1)
go func() {
defer s.wg.Done()

if err := s.FlushForwardingEvents(); err != nil {
err := s.gm.Go(func(ctx context.Context) {
err := s.FlushForwardingEvents()
if err != nil {
log.Errorf("unable to flush "+
"forwarding events: %v", err)
}
}()
})
if err != nil {
return
}

// The log ticker has fired, so we'll calculate some forwarding
// stats for the last 10 seconds to display within the logs to
Expand Down Expand Up @@ -1738,7 +1755,7 @@ out:
// memory.
s.pendingSettleFails = s.pendingSettleFails[:0]

case <-s.quit:
case <-s.gm.Done():
return
}
}
Expand All @@ -1759,8 +1776,15 @@ func (s *Switch) Start() error {
}
s.blockEpochStream = blockEpochStream

s.wg.Add(1)
go s.htlcForwarder()
err = s.gm.Go(func(ctx context.Context) {
s.htlcForwarder()
})
if err != nil {
s.Stop()
err = fmt.Errorf("unable to start htlc forwarder: %w", err)
log.Errorf("%v", err)
return err
}

if err := s.reforwardResponses(); err != nil {
s.Stop()
Expand Down Expand Up @@ -1990,9 +2014,11 @@ func (s *Switch) Stop() error {
log.Info("HTLC Switch shutting down...")
defer log.Debug("HTLC Switch shutdown complete")

close(s.quit)
// Ask running goroutines to stop and wait for them.
s.gm.Stop()

s.wg.Wait()
// TODO(yy): remove this, when s.handleLocalResponseWG is removed.
s.handleLocalResponseWG.Wait()

// Wait until all active goroutines have finished exiting before
// stopping the mailboxes, otherwise the mailbox map could still be
Expand Down Expand Up @@ -2348,7 +2374,7 @@ func (s *Switch) RemoveLink(chanID lnwire.ChannelID) {
select {
case <-stopChan:
return
case <-s.quit:
case <-s.gm.Done():
return
}
}
Expand Down Expand Up @@ -3020,7 +3046,7 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error {
// NOTE: `closeCircuit` modifies the state of `packet`.
if localHTLC {
// TODO(yy): remove the goroutine and send back the error here.
s.wg.Add(1)
s.handleLocalResponseWG.Add(1)
go s.handleLocalResponse(packet)

// If this is a locally initiated HTLC, there's no need to
Expand Down Expand Up @@ -3076,7 +3102,7 @@ func (s *Switch) handlePacketFail(packet *htlcPacket,
// NOTE: `closeCircuit` modifies the state of `packet`.
if packet.incomingChanID == hop.Source {
// TODO(yy): remove the goroutine and send back the error here.
s.wg.Add(1)
s.handleLocalResponseWG.Add(1)
go s.handleLocalResponse(packet)

// If this is a locally initiated HTLC, there's no need to
Expand Down
66 changes: 61 additions & 5 deletions htlcswitch/switch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"io"
mrand "math/rand"
"reflect"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -3180,6 +3181,58 @@ func TestSwitchGetAttemptResult(t *testing.T) {
}
}

func TestSwitchGetAttemptResultStress(t *testing.T) {
t.Parallel()

const paymentID = 123
var preimg lntypes.Preimage
preimg[0] = 3

s, err := initSwitchWithTempDB(t, testStartingHeight)
require.NoError(t, err, "unable to init switch")
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer func() {
require.NoError(t, s.Stop())
}()

lookup := make(chan *PaymentCircuit, 1)
s.circuits = &mockCircuitMap{
lookup: lookup,
}

var wg sync.WaitGroup

wg.Add(1)
go func() {
defer wg.Done()

for range 1000 {
// Next let the lookup find the circuit in the circuit
// map. It should subscribe to payment results, and
// return the result when available.
lookup <- &PaymentCircuit{}
_, err := s.GetAttemptResult(
paymentID, lntypes.Hash{},
newMockDeobfuscator(),
)
require.NoError(t, err, "unable to get payment result")
}
}()

wg.Add(1)
go func() {
defer wg.Done()

time.Sleep(10 * time.Millisecond)

require.NoError(t, s.Stop())
}()

wg.Wait()
}

// TestInvalidFailure tests that the switch returns an unreadable failure error
// if the failure cannot be decrypted.
func TestInvalidFailure(t *testing.T) {
Expand Down Expand Up @@ -4985,7 +5038,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) {
// Pull packet from Bob's link, and do nothing with it.
select {
case <-bobLink.packets:
case <-s.quit:
case <-s.gm.Done():
t.Fatal("switch shutting down, failed to forward packet")
}

Expand Down Expand Up @@ -5046,7 +5099,8 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) {
failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
require.True(t, ok)
require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID)
case <-s2.quit:

case <-s2.gm.Done():
t.Fatal("switch shutting down, failed to forward packet")
}
}
Expand Down Expand Up @@ -5229,7 +5283,8 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) {
failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure)
require.True(t, ok)
require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID)
case <-s.quit:

case <-s.gm.Done():
t.Fatal("switch shutting down, failed to receive fail packet")
}
}
Expand Down Expand Up @@ -5429,7 +5484,8 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private,
failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum)
require.True(t, ok)
require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID)
case <-s.quit:

case <-s.gm.Done():
t.Fatal("switch shutting down, failed to receive failure")
}
}
Expand Down Expand Up @@ -5587,7 +5643,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) {
isAlias := failScid == aliceAlias || failScid == aliceAlias2
require.True(t, isAlias)

case <-s.quit:
case <-s.gm.Done():
t.Fatalf("switch shutting down, failed to receive failure")
}

Expand Down

0 comments on commit 39b7a3b

Please sign in to comment.