diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 0e736f452f..1d2268afe0 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -804,7 +804,7 @@ func (c *ChannelArbitrator) relaunchResolvers(commitSet *CommitSet, // TODO(roasbeef): this isn't re-launched? } - c.launchResolvers(unresolvedContracts) + c.resolveContracts(unresolvedContracts) return nil } @@ -1343,7 +1343,7 @@ func (c *ChannelArbitrator) stateStep( // Finally, we'll launch all the required contract resolvers. // Once they're all resolved, we're no longer needed. - c.launchResolvers(resolvers) + c.resolveContracts(resolvers) nextState = StateWaitingFullResolution @@ -1566,18 +1566,47 @@ func (c *ChannelArbitrator) findCommitmentDeadlineAndValue(heightHint uint32, return fn.Some(int32(deadline)), valueLeft, nil } -// launchResolvers updates the activeResolvers list and starts the resolvers. -func (c *ChannelArbitrator) launchResolvers(resolvers []ContractResolver) { +// resolveContracts updates the activeResolvers list and starts to resolve each +// contract concurrently, and launches them. +func (c *ChannelArbitrator) resolveContracts(resolvers []ContractResolver) { c.activeResolversLock.Lock() c.activeResolvers = resolvers c.activeResolversLock.Unlock() + // Launch all resolvers. + c.launchResolvers() + for _, contract := range resolvers { c.wg.Add(1) go c.resolveContract(contract) } } +// launchResolvers launches all the active resolvers. +func (c *ChannelArbitrator) launchResolvers() { + c.activeResolversLock.Lock() + resolvers := c.activeResolvers + c.activeResolversLock.Unlock() + + for _, contract := range resolvers { + // If the contract is already resolved, there's no need to + // launch it again. + if contract.IsResolved() { + log.Debugf("ChannelArbitrator(%v): skipping resolver "+ + "%T as it's already resolved", c.cfg.ChanPoint, + contract) + + continue + } + + if err := contract.Launch(); err != nil { + log.Errorf("ChannelArbitrator(%v): unable to launch "+ + "contract resolver(%T): %v", c.cfg.ChanPoint, + contract, err) + } + } +} + // advanceState is the main driver of our state machine. This method is an // iterative function which repeatedly attempts to advance the internal state // of the channel arbitrator. The state will be advanced until we reach a @@ -2616,6 +2645,13 @@ func (c *ChannelArbitrator) resolveContract(currentContract ContractResolver) { // loop. currentContract = nextContract + // Launch the new contract. + err = currentContract.Launch() + if err != nil { + log.Errorf("Failed to launch %T: %v", + currentContract, err) + } + // If this contract is actually fully resolved, then // we'll mark it as such within the database. case currentContract.IsResolved(): @@ -3092,6 +3128,9 @@ func (c *ChannelArbitrator) handleBlockbeat(beat chainio.Blockbeat) error { } } + // Launch all active resolvers when a new blockbeat is received. + c.launchResolvers() + return nil } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index aa16820025..9e47569213 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -13,6 +13,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainio" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" @@ -1084,12 +1085,7 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { } // Send a notification that the expiry height has been reached. - // - // TODO(yy): remove the EpochChan and use the blockbeat below once - // resolvers are hooked with the blockbeat. oldNotifier.EpochChan <- &chainntnfs.BlockEpoch{Height: 10} - // beat := chainio.NewBlockbeatFromHeight(10) - // chanArb.BlockbeatChan <- beat // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. @@ -1132,8 +1128,12 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { default: } - // Notify resolver that the second level transaction is spent. - oldNotifier.SpendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + // Notify resolver that the output of the timeout tx has been spent. + oldNotifier.SpendChan <- &chainntnfs.SpendDetail{ + SpendingTx: closeTx, + SpentOutPoint: &wire.OutPoint{}, + SpenderTxHash: &closeTxid, + } // At this point channel should be marked as resolved. chanArbCtxNew.AssertStateTransitions(StateFullyResolved) @@ -2806,27 +2806,28 @@ func TestChannelArbitratorAnchors(t *testing.T) { } chanArb.UpdateContractSignals(signals) - // Set current block height. - beat = newBeatFromHeight(int32(heightHint)) - chanArbCtx.chanArb.BlockbeatChan <- beat - htlcAmt := lnwire.MilliSatoshi(1_000_000) // Create testing HTLCs. - deadlineDelta := uint32(10) - deadlinePreimageDelta := deadlineDelta + 2 + spendingHeight := uint32(beat.Height()) + deadlineDelta := uint32(100) + + deadlinePreimageDelta := deadlineDelta htlcWithPreimage := channeldb.HTLC{ - HtlcIndex: 99, - RefundTimeout: heightHint + deadlinePreimageDelta, + HtlcIndex: 99, + // RefundTimeout is 101. + RefundTimeout: spendingHeight + deadlinePreimageDelta, RHash: rHash, Incoming: true, Amt: htlcAmt, } + expectedDeadline := deadlineDelta/2 + spendingHeight - deadlineHTLCdelta := deadlineDelta + 3 + deadlineHTLCdelta := deadlineDelta + 40 htlc := channeldb.HTLC{ - HtlcIndex: 100, - RefundTimeout: heightHint + deadlineHTLCdelta, + HtlcIndex: 100, + // RefundTimeout is 141. + RefundTimeout: spendingHeight + deadlineHTLCdelta, Amt: htlcAmt, } @@ -2910,7 +2911,9 @@ func TestChannelArbitratorAnchors(t *testing.T) { } chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ - SpendDetail: &chainntnfs.SpendDetail{}, + SpendDetail: &chainntnfs.SpendDetail{ + SpendingHeight: int32(spendingHeight), + }, LocalForceCloseSummary: &lnwallet.LocalForceCloseSummary{ CloseTx: closeTx, HtlcResolutions: &lnwallet.HtlcResolutions{}, @@ -2972,16 +2975,14 @@ func TestChannelArbitratorAnchors(t *testing.T) { // to htlcWithPreimage's CLTV. require.Equal(t, 2, len(chanArbCtx.sweeper.deadlines)) require.EqualValues(t, - heightHint+deadlinePreimageDelta/2, + expectedDeadline, chanArbCtx.sweeper.deadlines[0], "want %d, got %d", - heightHint+deadlinePreimageDelta/2, - chanArbCtx.sweeper.deadlines[0], + expectedDeadline, chanArbCtx.sweeper.deadlines[0], ) require.EqualValues(t, - heightHint+deadlinePreimageDelta/2, + expectedDeadline, chanArbCtx.sweeper.deadlines[1], "want %d, got %d", - heightHint+deadlinePreimageDelta/2, - chanArbCtx.sweeper.deadlines[1], + expectedDeadline, chanArbCtx.sweeper.deadlines[1], ) } @@ -3129,7 +3130,8 @@ func assertResolverReport(t *testing.T, reports chan *channeldb.ResolverReport, select { case report := <-reports: if !reflect.DeepEqual(report, expected) { - t.Fatalf("expected: %v, got: %v", expected, report) + t.Fatalf("expected: %v, got: %v", spew.Sdump(expected), + spew.Sdump(report)) } case <-time.After(defaultTimeout): diff --git a/contractcourt/contract_resolver.go b/contractcourt/contract_resolver.go index 0ed5893329..d001dd6ebe 100644 --- a/contractcourt/contract_resolver.go +++ b/contractcourt/contract_resolver.go @@ -38,6 +38,17 @@ type ContractResolver interface { // resides within. ResolverKey() []byte + // Launch starts the resolver by constructing an input and offering it + // to the sweeper. Once offered, it's expected to monitor the sweeping + // result in a goroutine invoked by calling Resolve. + // + // NOTE: We can call `Resolve` inside a goroutine at the end of this + // method to avoid calling it in the ChannelArbitrator. However, there + // are some DB-related operations such as SwapContract/ResolveContract + // which need to be done inside the resolvers instead, which needs a + // deeper refactoring. + Launch() error + // Resolve instructs the contract resolver to resolve the output // on-chain. Once the output has been *fully* resolved, the function // should return immediately with a nil ContractResolver value for the diff --git a/contractcourt/htlc_success_resolver_test.go b/contractcourt/htlc_success_resolver_test.go index b92fcb678f..92ebecd836 100644 --- a/contractcourt/htlc_success_resolver_test.go +++ b/contractcourt/htlc_success_resolver_test.go @@ -146,6 +146,9 @@ func (i *htlcResolverTestContext) resolve() { i.resolverResultChan = make(chan resolveResult, 1) go func() { + err := i.resolver.Launch() + require.NoError(i.t, err) + nextResolver, err := i.resolver.Resolve() i.resolverResultChan <- resolveResult{ nextResolver: nextResolver,