diff --git a/multinode/transaction_sender.go b/multinode/transaction_sender.go
index 49322d4..c0c8a9b 100644
--- a/multinode/transaction_sender.go
+++ b/multinode/transaction_sender.go
@@ -91,6 +91,8 @@ type TransactionSender[TX any, RESULT SendTxResult, CHAIN_ID ID, RPC SendTxRPCCl
 // * Otherwise, returns any (effectively random) of the errors.
 func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) SendTransaction(ctx context.Context, tx TX) RESULT {
 	var result RESULT
+	ctx, cancel := txSender.chStop.Ctx(ctx)
+	defer cancel()
 	if !txSender.IfStarted(func() {
 		txResults := make(chan RESULT)
 		txResultsToReport := make(chan RESULT)
@@ -101,8 +103,6 @@ func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) SendTransaction(ct
 			if isSendOnly {
 				txSender.wg.Add(1)
 				go func(ctx context.Context) {
-					ctx, cancel := txSender.chStop.Ctx(context.WithoutCancel(ctx))
-					defer cancel()
 					defer txSender.wg.Done()
 					// Send-only nodes' results are ignored as they tend to return false-positive responses.
 					// Broadcast to them is necessary to speed up the propagation of TX in the network.
@@ -115,8 +115,9 @@ func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) SendTransaction(ct
 			healthyNodesNum++
 			primaryNodeWg.Add(1)
 			go func(ctx context.Context) {
-				ctx, cancel := txSender.chStop.Ctx(context.WithoutCancel(ctx))
-				defer cancel()
+				// Broadcasting the transaction and reporting results for invariant detection are background jobs
+				// that must be detached from the caller's cancellation.
+				// Reporting results to the SendTransaction caller must respect the caller's context to avoid a goroutine leak.
 				defer primaryNodeWg.Done()
 				r := txSender.broadcastTxAsync(ctx, rpc, tx)
 				select {
@@ -126,6 +127,8 @@ func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) SendTransaction(ct
 				case txResults <- r:
 				}
 
+				ctx, cancel := txSender.chStop.Ctx(context.WithoutCancel(ctx))
+				defer cancel()
 				select {
 				case <-ctx.Done():
 					txSender.lggr.Debugw("Failed to send tx results to report", "err", ctx.Err())
@@ -149,8 +152,13 @@
 			return
 		}
 
+		if healthyNodesNum == 0 {
+			result = txSender.newResult(ErroringNodeError)
+			return
+		}
+
 		txSender.wg.Add(1)
-		go txSender.reportSendTxAnomalies(ctx, tx, txResultsToReport)
+		go txSender.reportSendTxAnomalies(tx, txResultsToReport)
 
 		result = txSender.collectTxResults(ctx, tx, healthyNodesNum, txResults)
 	}) {
@@ -161,6 +169,9 @@
 }
 
 func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) broadcastTxAsync(ctx context.Context, rpc RPC, tx TX) RESULT {
+	// broadcast is a background job, so always detach from the caller's cancellation
+	ctx, cancel := txSender.chStop.Ctx(context.WithoutCancel(ctx))
+	defer cancel()
 	result := rpc.SendTransaction(ctx, tx)
 	txSender.lggr.Debugw("Node sent transaction", "tx", tx, "err", result.Error())
 	if !slices.Contains(sendTxSuccessfulCodes, result.Code()) && ctx.Err() == nil {
@@ -169,7 +180,7 @@ func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) broadcastTxAsync(c
 	return result
 }
 
-func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) reportSendTxAnomalies(ctx context.Context, tx TX, txResults <-chan RESULT) {
+func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) reportSendTxAnomalies(tx TX, txResults <-chan RESULT) {
 	defer txSender.wg.Done()
 	resultsByCode := sendTxResults[RESULT]{}
 	// txResults eventually will be closed
@@ -177,8 +188,17 @@ func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) reportSendTxAnomal
 		resultsByCode[txResult.Code()] = append(resultsByCode[txResult.Code()], txResult)
 	}
 
+	select {
+	case <-txSender.chStop:
+		// It's ok to receive no results if txSender is closing. Return early to prevent falsely reporting an invariant violation.
+		if len(resultsByCode) == 0 {
+			return
+		}
+	default:
+	}
+
 	_, criticalErr := aggregateTxResults[RESULT](resultsByCode)
-	if criticalErr != nil && ctx.Err() == nil {
+	if criticalErr != nil {
 		txSender.lggr.Criticalw("observed invariant violation on SendTransaction", "tx", tx, "resultsByCode", resultsByCode, "err", criticalErr)
 		PromMultiNodeInvariantViolations.WithLabelValues(txSender.chainFamily, txSender.chainID.String(), criticalErr.Error()).Inc()
 	}
@@ -216,9 +236,6 @@
 }
 
 func (txSender *TransactionSender[TX, RESULT, CHAIN_ID, RPC]) collectTxResults(ctx context.Context, tx TX, healthyNodesNum int, txResults <-chan RESULT) RESULT {
-	if healthyNodesNum == 0 {
-		return txSender.newResult(ErroringNodeError)
-	}
 	requiredResults := int(math.Ceil(float64(healthyNodesNum) * sendTxQuorum))
 	errorsByCode := sendTxResults[RESULT]{}
 	var softTimeoutChan <-chan time.Time
diff --git a/multinode/transaction_sender_test.go b/multinode/transaction_sender_test.go
index 6405639..62d6523 100644
--- a/multinode/transaction_sender_test.go
+++ b/multinode/transaction_sender_test.go
@@ -292,6 +292,27 @@ func TestTransactionSender_SendTransaction(t *testing.T) {
 		require.NoError(t, result.Error())
 		require.Equal(t, Successful, result.Code())
 	})
+	t.Run("All background jobs stop even if RPC returns result after soft timeout", func(t *testing.T) {
+		chainID := RandomID()
+		expectedError := errors.New("transaction failed")
+		fastNode := newNode(t, expectedError, nil)
+
+		// hold the reply from the slow node until SendTransaction returns a result
+		sendTxContext, sendTxCancel := context.WithCancel(tests.Context(t))
+		slowNode := newNode(t, errors.New("transaction failed"), func(_ mock.Arguments) {
+			<-sendTxContext.Done()
+		})
+
+		lggr := logger.Test(t)
+
+		_, txSender := newTestTransactionSender(t, chainID, lggr, []Node[ID, TestSendTxRPCClient]{fastNode, slowNode}, nil)
+		result := txSender.SendTransaction(sendTxContext, nil)
+		sendTxCancel()
+		require.EqualError(t, result.Error(), expectedError.Error())
+		// TxSender must stop all background goroutines after SendTransaction returns and before the test ends.
+		// Otherwise, it signals a goroutine leak.
+		txSender.wg.Wait()
+	})
 }
 
 func TestTransactionSender_SendTransaction_aggregateTxResults(t *testing.T) {
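
The patch hinges on one pattern: background work must ignore the caller's cancellation (via context.WithoutCancel) while still honoring the sender's stop channel (chStop). Below is a minimal, self-contained Go sketch of that pattern; the stopChan type and its Ctx helper are illustrative stand-ins, not the actual chStop implementation used by TransactionSender.

// Minimal sketch of the "detached background context" pattern used in this patch.
// stopChan and its Ctx helper are illustrative stand-ins for txSender.chStop.
package main

import (
	"context"
	"fmt"
	"time"
)

// stopChan mimics a service-level stop channel.
type stopChan chan struct{}

// Ctx derives a context from parent that is additionally cancelled when the stop channel closes.
func (s stopChan) Ctx(parent context.Context) (context.Context, context.CancelFunc) {
	ctx, cancel := context.WithCancel(parent)
	go func() {
		select {
		case <-s:
			cancel()
		case <-ctx.Done():
		}
	}()
	return ctx, cancel
}

func main() {
	chStop := make(stopChan)
	callerCtx, callerCancel := context.WithCancel(context.Background())

	// Detach from the caller's cancellation but stay bound to chStop:
	// context.WithoutCancel (Go 1.21+) drops callerCtx's cancellation while keeping its values.
	bgCtx, bgCancel := chStop.Ctx(context.WithoutCancel(callerCtx))
	defer bgCancel()

	callerCancel()
	fmt.Println("after caller cancel:", bgCtx.Err()) // <nil>: background work keeps running

	close(chStop)
	time.Sleep(10 * time.Millisecond)               // let the watcher goroutine observe the close
	fmt.Println("after chStop close:", bgCtx.Err()) // context canceled: shutdown still stops it
}

This is why broadcastTxAsync can detach unconditionally, while the hand-off of the result back to the SendTransaction caller keeps the original ctx, so the reporting goroutine exits promptly once the caller gives up.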