From c17a0b0495d28350ae610eca7433077f5d78d12f Mon Sep 17 00:00:00 2001 From: Ian Shim <100327837+ian-shim@users.noreply.github.com> Date: Tue, 4 Jun 2024 07:10:45 +0900 Subject: [PATCH] Filter empty quorums when confirming batch (#581) --- core/aggregation.go | 137 ++++++++++++++----- core/aggregation_test.go | 130 +++++++++++++++++- core/mock/state.go | 4 +- core/mock/tx.go | 2 +- disperser/batcher/batcher.go | 50 ++++--- disperser/batcher/batcher_test.go | 215 ++++++++++++++++++++++++++++-- disperser/mock/dispatcher.go | 35 +++-- test/integration_test.go | 3 +- 8 files changed, 495 insertions(+), 81 deletions(-) diff --git a/core/aggregation.go b/core/aggregation.go index 4a17525748..963c51b454 100644 --- a/core/aggregation.go +++ b/core/aggregation.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "math/big" + "slices" "sort" "github.com/Layr-Labs/eigensdk-go/logging" @@ -32,13 +33,30 @@ type SigningMessage struct { Err error } -// SignatureAggregation contains the results of aggregating signatures from a set of operators +// QuorumAttestation contains the results of aggregating signatures from a set of operators by quorums +// It also returns map of all signers across all quorums +type QuorumAttestation struct { + // QuorumAggPubKeys contains the aggregated public keys for all of the operators each quorum, + // including those that did not sign + QuorumAggPubKey map[QuorumID]*G1Point + // SignersAggPubKey is the aggregated public key for all of the operators that signed the message by each quorum + SignersAggPubKey map[QuorumID]*G2Point + // AggSignature is the aggregated signature for all of the operators that signed the message for each quorum, mirroring the + // SignersAggPubKey. + AggSignature map[QuorumID]*Signature + // QuorumResults contains the quorum ID and the amount signed for each quorum + QuorumResults map[QuorumID]*QuorumResult + // SignerMap contains the operator IDs that signed the message + SignerMap map[OperatorID]bool +} + +// SignatureAggregation contains the results of aggregating signatures from a set of operators across multiple quorums type SignatureAggregation struct { // NonSigners contains the public keys of the operators that did not sign the message NonSigners []*G1Point // QuorumAggPubKeys contains the aggregated public keys for all of the operators each quorum, // Including those that did not sign - QuorumAggPubKeys []*G1Point + QuorumAggPubKeys map[QuorumID]*G1Point // AggPubKey is the aggregated public key for all of the operators that signed the message, // further aggregated across the quorums; operators signing for multiple quorums will be included in // the aggregation multiple times @@ -48,16 +66,15 @@ type SignatureAggregation struct { AggSignature *Signature // QuorumResults contains the quorum ID and the amount signed for each quorum QuorumResults map[QuorumID]*QuorumResult - // SignerMap contains the operator IDs that signed the message - SignerMap map[OperatorID]bool } // SignatureAggregator is an interface for aggregating the signatures returned by DA nodes so that they can be verified by the DA contract type SignatureAggregator interface { - - // AggregateSignatures blocks until it receives a response for each operator in the operator state via messageChan, and then returns the aggregated signature. + // ReceiveSignatures blocks until it receives a response for each operator in the operator state via messageChan, and then returns the attestation result by quorum. + ReceiveSignatures(ctx context.Context, state *IndexedOperatorState, message [32]byte, messageChan chan SigningMessage) (*QuorumAttestation, error) + // AggregateSignatures takes attestation result by quorum and aggregates the signatures across them. // If the aggregated signature is invalid, an error is returned. - AggregateSignatures(ctx context.Context, state *IndexedOperatorState, quorumIDs []QuorumID, message [32]byte, messageChan chan SigningMessage) (*SignatureAggregation, error) + AggregateSignatures(ctx context.Context, ics IndexedChainState, referenceBlockNumber uint, quorumAttestation *QuorumAttestation, quorumIDs []QuorumID) (*SignatureAggregation, error) } type StdSignatureAggregator struct { @@ -82,8 +99,12 @@ func NewStdSignatureAggregator(logger logging.Logger, transactor Transactor) (*S var _ SignatureAggregator = (*StdSignatureAggregator)(nil) -func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state *IndexedOperatorState, quorumIDs []QuorumID, message [32]byte, messageChan chan SigningMessage) (*SignatureAggregation, error) { - // TODO: Add logging +func (a *StdSignatureAggregator) ReceiveSignatures(ctx context.Context, state *IndexedOperatorState, message [32]byte, messageChan chan SigningMessage) (*QuorumAttestation, error) { + quorumIDs := make([]QuorumID, 0, len(state.AggKeys)) + for quorumID := range state.Operators { + quorumIDs = append(quorumIDs, quorumID) + } + slices.Sort(quorumIDs) if len(quorumIDs) == 0 { return nil, errors.New("the number of quorums must be greater than zero") @@ -97,13 +118,12 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state } } - stakeSigned := make([]*big.Int, len(quorumIDs)) - for ind := range quorumIDs { - stakeSigned[ind] = big.NewInt(0) + stakeSigned := make(map[QuorumID]*big.Int, len(quorumIDs)) + for _, quorumID := range quorumIDs { + stakeSigned[quorumID] = big.NewInt(0) } - aggSigs := make([]*Signature, len(quorumIDs)) - aggPubKeys := make([]*G2Point, len(quorumIDs)) - + aggSigs := make(map[QuorumID]*Signature, len(quorumIDs)) + aggPubKeys := make(map[QuorumID]*G2Point, len(quorumIDs)) signerMap := make(map[OperatorID]bool) // Aggregate Signatures @@ -151,7 +171,7 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state } operatorQuorums := make([]uint8, 0, len(quorumIDs)) - for ind, quorumID := range quorumIDs { + for _, quorumID := range quorumIDs { // Get stake amounts for operator ops := state.Operators[quorumID] opInfo, ok := ops[r.Operator] @@ -164,15 +184,15 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state signerMap[r.Operator] = true // Add to stake signed - stakeSigned[ind].Add(stakeSigned[ind], opInfo.Stake) + stakeSigned[quorumID].Add(stakeSigned[quorumID], opInfo.Stake) // Add to agg signature - if aggSigs[ind] == nil { - aggSigs[ind] = &Signature{sig.Clone()} - aggPubKeys[ind] = op.PubkeyG2.Clone() + if aggSigs[quorumID] == nil { + aggSigs[quorumID] = &Signature{sig.Clone()} + aggPubKeys[quorumID] = op.PubkeyG2.Clone() } else { - aggSigs[ind].Add(sig.G1Point) - aggPubKeys[ind].Add(op.PubkeyG2) + aggSigs[quorumID].Add(sig.G1Point) + aggPubKeys[quorumID].Add(op.PubkeyG2) } } a.Logger.Info("received signature from operator", "operatorID", operatorIDHex, "operatorAddress", operatorAddr, "socket", socket, "quorumIDs", fmt.Sprint(operatorQuorums), "batchHeaderHash", batchHeaderHashHex, "attestationLatencyMs", r.AttestationLatencyMs) @@ -190,14 +210,14 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state } } - quorumAggPubKeys := make([]*G1Point, len(quorumIDs)) + quorumAggPubKeys := make(map[QuorumID]*G1Point, len(quorumIDs)) // Validate the amount signed and aggregate signatures for each quorum quorumResults := make(map[QuorumID]*QuorumResult) - for ind, quorumID := range quorumIDs { + for _, quorumID := range quorumIDs { // Check that quorum has sufficient stake - percent := GetSignedPercentage(state.OperatorState, quorumID, stakeSigned[ind]) + percent := GetSignedPercentage(state.OperatorState, quorumID, stakeSigned[quorumID]) quorumResults[quorumID] = &QuorumResult{ QuorumID: quorumID, PercentSigned: percent, @@ -205,7 +225,7 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state // Verify that the aggregated public key for the quorum matches the on-chain quorum aggregate public key sans non-signers of the quorum quorumAggKey := state.AggKeys[quorumID] - quorumAggPubKeys[ind] = quorumAggKey + quorumAggPubKeys[quorumID] = quorumAggKey signersAggKey := quorumAggKey.Clone() for opInd, nsk := range nonSignerKeys { @@ -215,11 +235,11 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state } } - if aggPubKeys[ind] == nil { + if aggPubKeys[quorumID] == nil { return nil, ErrAggSigNotValid } - ok, err := signersAggKey.VerifyEquivalence(aggPubKeys[ind]) + ok, err := signersAggKey.VerifyEquivalence(aggPubKeys[quorumID]) if err != nil { return nil, err } @@ -228,20 +248,54 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state } // Verify the aggregated signature for the quorum - ok = aggSigs[ind].Verify(aggPubKeys[ind], message) + ok = aggSigs[quorumID].Verify(aggPubKeys[quorumID], message) if !ok { return nil, ErrAggSigNotValid } } + return &QuorumAttestation{ + QuorumAggPubKey: quorumAggPubKeys, + SignersAggPubKey: aggPubKeys, + AggSignature: aggSigs, + QuorumResults: quorumResults, + SignerMap: signerMap, + }, nil +} + +func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, ics IndexedChainState, referenceBlockNumber uint, quorumAttestation *QuorumAttestation, quorumIDs []QuorumID) (*SignatureAggregation, error) { // Aggregate the aggregated signatures. We reuse the first aggregated signature as the accumulator - for i := 1; i < len(aggSigs); i++ { - aggSigs[0].Add(aggSigs[i].G1Point) + var aggSig *Signature + for _, quorumID := range quorumIDs { + sig := quorumAttestation.AggSignature[quorumID] + if aggSig == nil { + aggSig = &Signature{sig.G1Point.Clone()} + } else { + aggSig.Add(sig.G1Point) + } } // Aggregate the aggregated public keys. We reuse the first aggregated public key as the accumulator - for i := 1; i < len(aggPubKeys); i++ { - aggPubKeys[0].Add(aggPubKeys[i]) + var aggPubKey *G2Point + for _, quorumID := range quorumIDs { + apk := quorumAttestation.SignersAggPubKey[quorumID] + if aggPubKey == nil { + aggPubKey = apk.Clone() + } else { + aggPubKey.Add(apk) + } + } + + nonSignerKeys := make([]*G1Point, 0) + indexedOperatorState, err := ics.GetIndexedOperatorState(ctx, referenceBlockNumber, quorumIDs) + if err != nil { + return nil, err + } + for id, op := range indexedOperatorState.IndexedOperators { + _, found := quorumAttestation.SignerMap[id] + if !found { + nonSignerKeys = append(nonSignerKeys, op.PubkeyG1) + } } // sort non signer keys according to how it's checked onchain @@ -253,13 +307,22 @@ func (a *StdSignatureAggregator) AggregateSignatures(ctx context.Context, state return bytes.Compare(hash1[:], hash2[:]) == -1 }) + quorumAggKeys := make(map[QuorumID]*G1Point, len(quorumIDs)) + for _, quorumID := range quorumIDs { + quorumAggKeys[quorumID] = quorumAttestation.QuorumAggPubKey[quorumID] + } + + quorumResults := make(map[QuorumID]*QuorumResult, len(quorumIDs)) + for _, quorumID := range quorumIDs { + quorumResults[quorumID] = quorumAttestation.QuorumResults[quorumID] + } + return &SignatureAggregation{ NonSigners: nonSignerKeys, - QuorumAggPubKeys: quorumAggPubKeys, - AggPubKey: aggPubKeys[0], - AggSignature: aggSigs[0], + QuorumAggPubKeys: quorumAggKeys, + AggPubKey: aggPubKey, + AggSignature: aggSig, QuorumResults: quorumResults, - SignerMap: signerMap, }, nil } diff --git a/core/aggregation_test.go b/core/aggregation_test.go index 61d722cc5e..9f2981d69e 100644 --- a/core/aggregation_test.go +++ b/core/aggregation_test.go @@ -164,7 +164,32 @@ func TestAggregateSignaturesStatus(t *testing.T) { quorumIDs[ind] = quorum.QuorumID } - sigAgg, err := agg.AggregateSignatures(context.Background(), state.IndexedOperatorState, quorumIDs, message, update) + numOpr := 0 + for _, quorum := range tt.quorums { + if len(dat.Stakes[quorum.QuorumID]) > numOpr { + numOpr = len(dat.Stakes[quorum.QuorumID]) + } + } + + aq, err := agg.ReceiveSignatures(context.Background(), state.IndexedOperatorState, message, update) + assert.NoError(t, err) + assert.Len(t, aq.SignerMap, numOpr-int(tt.adversaryCount)) + assert.Len(t, aq.AggSignature, 2) + assert.Len(t, aq.QuorumAggPubKey, 2) + assert.Len(t, aq.SignersAggPubKey, 2) + assert.Len(t, aq.QuorumResults, 2) + for i, q := range tt.quorums { + assert.NotNil(t, aq.AggSignature[q.QuorumID]) + assert.NotNil(t, aq.QuorumAggPubKey[q.QuorumID]) + assert.NotNil(t, aq.SignersAggPubKey[q.QuorumID]) + if tt.meetsQuorum[i] { + assert.GreaterOrEqual(t, aq.QuorumResults[q.QuorumID].PercentSigned, q.PercentSigned) + } else { + assert.Less(t, aq.QuorumResults[q.QuorumID].PercentSigned, q.PercentSigned) + } + } + + sigAgg, err := agg.AggregateSignatures(context.Background(), dat, 0, aq, quorumIDs) assert.NoError(t, err) for i, quorum := range tt.quorums { @@ -180,7 +205,6 @@ func TestAggregateSignaturesStatus(t *testing.T) { } func TestSortNonsigners(t *testing.T) { - state := dat.GetTotalOperatorState(context.Background(), 0) update := make(chan core.SigningMessage) @@ -190,7 +214,9 @@ func TestSortNonsigners(t *testing.T) { quorums := []core.QuorumID{0} - sigAgg, err := agg.AggregateSignatures(context.Background(), state.IndexedOperatorState, quorums, message, update) + aq, err := agg.ReceiveSignatures(context.Background(), state.IndexedOperatorState, message, update) + assert.NoError(t, err) + sigAgg, err := agg.AggregateSignatures(context.Background(), dat, 0, aq, quorums) assert.NoError(t, err) for i := range sigAgg.NonSigners { @@ -204,3 +230,101 @@ func TestSortNonsigners(t *testing.T) { assert.Equal(t, currHashInt.Cmp(prevHashInt), 1) } } + +func TestFilterQuorums(t *testing.T) { + allQuorums := []core.QuorumID{0, 1} + state := dat.GetTotalOperatorStateWithQuorums(context.Background(), 0, allQuorums) + + update := make(chan core.SigningMessage) + message := [32]byte{1, 2, 3, 4, 5, 6} + advCount := 4 + go simulateOperators(*state, message, update, uint(advCount)) + + numOpr := 0 + for _, quorum := range allQuorums { + if len(dat.Stakes[quorum]) > numOpr { + numOpr = len(dat.Stakes[quorum]) + } + } + + aq, err := agg.ReceiveSignatures(context.Background(), state.IndexedOperatorState, message, update) + assert.NoError(t, err) + assert.Len(t, aq.SignerMap, numOpr-advCount) + assert.Equal(t, aq.SignerMap, map[core.OperatorID]bool{ + mock.MakeOperatorId(0): true, + mock.MakeOperatorId(1): true, + }) + assert.Contains(t, aq.AggSignature, core.QuorumID(0)) + assert.Contains(t, aq.AggSignature, core.QuorumID(1)) + assert.Equal(t, aq.QuorumAggPubKey, map[core.QuorumID]*core.G1Point{ + core.QuorumID(0): state.IndexedOperatorState.AggKeys[0], + core.QuorumID(1): state.IndexedOperatorState.AggKeys[1], + }) + aggSignerPubKey0 := state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(0)].PubkeyG2.Clone() + aggSignerPubKey0.Add(state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(1)].PubkeyG2) + aggSignerPubKey1 := state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(0)].PubkeyG2.Clone() + aggSignerPubKey1.Add(state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(1)].PubkeyG2) + assert.Contains(t, aq.SignersAggPubKey, core.QuorumID(0)) + assert.Equal(t, aq.SignersAggPubKey[core.QuorumID(0)], aggSignerPubKey0) + assert.Contains(t, aq.SignersAggPubKey, core.QuorumID(1)) + assert.Equal(t, aq.SignersAggPubKey[core.QuorumID(1)], aggSignerPubKey1) + assert.Equal(t, aq.QuorumResults[core.QuorumID(0)].PercentSigned, uint8(14)) + assert.Equal(t, aq.QuorumResults[core.QuorumID(1)].PercentSigned, uint8(50)) + + // Only consider quorum 0 + quorums := []core.QuorumID{0} + sigAgg, err := agg.AggregateSignatures(context.Background(), dat, 0, aq, quorums) + assert.NoError(t, err) + assert.Len(t, sigAgg.NonSigners, 4) + assert.ElementsMatch(t, sigAgg.NonSigners, []*core.G1Point{ + state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(2)].PubkeyG1, + state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(3)].PubkeyG1, + state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(4)].PubkeyG1, + state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(5)].PubkeyG1, + }) + assert.Len(t, sigAgg.QuorumAggPubKeys, 1) + assert.Contains(t, sigAgg.QuorumAggPubKeys, core.QuorumID(0)) + assert.Equal(t, sigAgg.QuorumAggPubKeys[0], state.IndexedOperatorState.AggKeys[0]) + + assert.Equal(t, sigAgg.AggPubKey, aggSignerPubKey0) + expectedAggSignerKey := sigAgg.QuorumAggPubKeys[0].Clone() + for _, nsk := range sigAgg.NonSigners { + expectedAggSignerKey.Sub(nsk) + } + ok, err := expectedAggSignerKey.VerifyEquivalence(sigAgg.AggPubKey) + assert.NoError(t, err) + assert.True(t, ok) + ok = sigAgg.AggSignature.Verify(sigAgg.AggPubKey, message) + assert.True(t, ok) + assert.Len(t, sigAgg.QuorumResults, 1) + assert.Contains(t, sigAgg.QuorumResults, core.QuorumID(0)) + assert.Equal(t, sigAgg.QuorumResults[0].QuorumID, core.QuorumID(0)) + assert.Equal(t, sigAgg.QuorumResults[0].PercentSigned, core.QuorumID(14)) + + // Only consider quorum 1 + quorums = []core.QuorumID{1} + sigAgg, err = agg.AggregateSignatures(context.Background(), dat, 0, aq, quorums) + assert.NoError(t, err) + assert.Len(t, sigAgg.NonSigners, 1) + assert.ElementsMatch(t, sigAgg.NonSigners, []*core.G1Point{ + state.IndexedOperatorState.IndexedOperators[mock.MakeOperatorId(2)].PubkeyG1, + }) + assert.Len(t, sigAgg.QuorumAggPubKeys, 1) + assert.Contains(t, sigAgg.QuorumAggPubKeys, core.QuorumID(1)) + assert.Equal(t, sigAgg.QuorumAggPubKeys[1], state.IndexedOperatorState.AggKeys[1]) + + assert.Equal(t, sigAgg.AggPubKey, aggSignerPubKey1) + expectedAggSignerKey = sigAgg.QuorumAggPubKeys[1].Clone() + for _, nsk := range sigAgg.NonSigners { + expectedAggSignerKey.Sub(nsk) + } + ok, err = expectedAggSignerKey.VerifyEquivalence(sigAgg.AggPubKey) + assert.NoError(t, err) + assert.True(t, ok) + ok = sigAgg.AggSignature.Verify(sigAgg.AggPubKey, message) + assert.True(t, ok) + assert.Len(t, sigAgg.QuorumResults, 1) + assert.Contains(t, sigAgg.QuorumResults, core.QuorumID(1)) + assert.Equal(t, sigAgg.QuorumResults[1].QuorumID, core.QuorumID(1)) + assert.Equal(t, sigAgg.QuorumResults[1].PercentSigned, core.QuorumID(50)) +} diff --git a/core/mock/state.go b/core/mock/state.go index 324febd621..49ef526aed 100644 --- a/core/mock/state.go +++ b/core/mock/state.go @@ -190,6 +190,7 @@ func (d *ChainDataMock) GetTotalOperatorStateWithQuorums(ctx context.Context, bl BlockNumber: blockNumber, } + filteredIndexedOperators := make(map[core.OperatorID]*core.IndexedOperatorInfo, 0) for quorumID, operatorsByID := range storedOperators { for opID := range operatorsByID { if aggPubKeys[quorumID] == nil { @@ -198,12 +199,13 @@ func (d *ChainDataMock) GetTotalOperatorStateWithQuorums(ctx context.Context, bl } else { aggPubKeys[quorumID].Add(privateOperators[opID].KeyPair.GetPubKeyG1()) } + filteredIndexedOperators[opID] = indexedOperators[opID] } } indexedState := &core.IndexedOperatorState{ OperatorState: operatorState, - IndexedOperators: indexedOperators, + IndexedOperators: filteredIndexedOperators, AggKeys: make(map[core.QuorumID]*core.G1Point), } for quorumID, apk := range aggPubKeys { diff --git a/core/mock/tx.go b/core/mock/tx.go index f2ba688d91..b56c98b93d 100644 --- a/core/mock/tx.go +++ b/core/mock/tx.go @@ -90,7 +90,7 @@ func (t *MockTransactor) GetOperatorStakesForQuorums(ctx context.Context, quorum } func (t *MockTransactor) BuildConfirmBatchTxn(ctx context.Context, batchHeader *core.BatchHeader, quorums map[core.QuorumID]*core.QuorumResult, signatureAggregation *core.SignatureAggregation) (*types.Transaction, error) { - args := t.Called() + args := t.Called(ctx, batchHeader, quorums, signatureAggregation) result := args.Get(0) return result.(*types.Transaction), args.Error(1) } diff --git a/disperser/batcher/batcher.go b/disperser/batcher/batcher.go index 738deeb588..11308f4000 100644 --- a/disperser/batcher/batcher.go +++ b/disperser/batcher/batcher.go @@ -7,7 +7,6 @@ import ( "fmt" "math" "math/big" - "slices" "time" "github.com/Layr-Labs/eigenda/common" @@ -441,21 +440,12 @@ func (b *Batcher) HandleSingleBatch(ctx context.Context) error { // Aggregate the signatures log.Debug("Aggregating signatures...") - // construct quorumParams - quorumIDs := make([]core.QuorumID, 0, len(batch.State.AggKeys)) - for quorumID := range batch.State.Operators { - quorumIDs = append(quorumIDs, quorumID) - } - slices.Sort(quorumIDs) - stageTimer = time.Now() - aggSig, err := b.Aggregator.AggregateSignatures(ctx, batch.State, quorumIDs, headerHash, update) + quorumAttestation, err := b.Aggregator.ReceiveSignatures(ctx, batch.State, headerHash, update) if err != nil { _ = b.handleFailure(ctx, batch.BlobMetadata, FailAggregateSignatures) - return fmt.Errorf("HandleSingleBatch: error aggregating signatures: %w", err) + return fmt.Errorf("HandleSingleBatch: error receiving and validating signatures: %w", err) } - log.Debug("AggregateSignatures took", "duration", time.Since(stageTimer)) - b.Metrics.ObserveLatency("AggregateSignatures", float64(time.Since(stageTimer).Milliseconds())) operatorCount := make(map[core.QuorumID]int) signerCount := make(map[core.QuorumID]int) for quorumID, opState := range batch.State.Operators { @@ -464,23 +454,39 @@ func (b *Batcher) HandleSingleBatch(ctx context.Context) error { signerCount[quorumID] = 0 } for opID := range opState { - if _, ok := aggSig.SignerMap[opID]; ok { + if _, ok := quorumAttestation.SignerMap[opID]; ok { signerCount[quorumID]++ } } } - b.Metrics.UpdateAttestation(operatorCount, signerCount, aggSig.QuorumResults) - for _, quorumResult := range aggSig.QuorumResults { + b.Metrics.UpdateAttestation(operatorCount, signerCount, quorumAttestation.QuorumResults) + for _, quorumResult := range quorumAttestation.QuorumResults { log.Info("Aggregated quorum result", "quorumID", quorumResult.QuorumID, "percentSigned", quorumResult.PercentSigned) } - numPassed := numBlobsAttested(aggSig.QuorumResults, batch.BlobHeaders) + numPassed, passedQuorums := numBlobsAttestedByQuorum(quorumAttestation.QuorumResults, batch.BlobHeaders) // TODO(mooselumph): Determine whether to confirm the batch based on the number of successes if numPassed == 0 { _ = b.handleFailure(ctx, batch.BlobMetadata, FailNoSignatures) return errors.New("HandleSingleBatch: no blobs received sufficient signatures") } + nonEmptyQuorums := []core.QuorumID{} + for quorumID := range passedQuorums { + log.Info("Quorums successfully attested", "quorumID", quorumID) + nonEmptyQuorums = append(nonEmptyQuorums, quorumID) + } + + // Aggregate the signatures across only the non-empty quorums. Excluding empty quorums reduces the gas cost. + aggSig, err := b.Aggregator.AggregateSignatures(ctx, b.ChainState, batch.BatchHeader.ReferenceBlockNumber, quorumAttestation, nonEmptyQuorums) + if err != nil { + _ = b.handleFailure(ctx, batch.BlobMetadata, FailAggregateSignatures) + return fmt.Errorf("HandleSingleBatch: error aggregating signatures: %w", err) + } + + log.Debug("AggregateSignatures took", "duration", time.Since(stageTimer)) + b.Metrics.ObserveLatency("AggregateSignatures", float64(time.Since(stageTimer).Milliseconds())) + // Confirm the batch log.Debug("Confirming batch...") @@ -588,15 +594,19 @@ func (b *Batcher) getBatchID(ctx context.Context, txReceipt *types.Receipt) (uin return batchID, nil } -// numBlobsAttested returns the number of blobs that have been successfully attested by the given quorums -func numBlobsAttested(signedQuorums map[core.QuorumID]*core.QuorumResult, headers []*core.BlobHeader) int { +// numBlobsAttestedByQuorum returns two values: +// 1. the number of blobs that have been successfully attested by all quorums +// 2. map[QuorumID]struct{} contains quorums that have been successfully attested by the quorum (has at least one blob attested in the quorum) +func numBlobsAttestedByQuorum(signedQuorums map[core.QuorumID]*core.QuorumResult, headers []*core.BlobHeader) (int, map[core.QuorumID]struct{}) { numPassed := 0 + quorums := make(map[core.QuorumID]struct{}) for _, blob := range headers { thisPassed := true for _, quorum := range blob.QuorumInfos { if signedQuorums[quorum.QuorumID].PercentSigned < quorum.ConfirmationThreshold { thisPassed = false - break + } else { + quorums[quorum.QuorumID] = struct{}{} } } if thisPassed { @@ -604,7 +614,7 @@ func numBlobsAttested(signedQuorums map[core.QuorumID]*core.QuorumResult, header } } - return numPassed + return numPassed, quorums } func isBlobAttested(signedQuorums map[core.QuorumID]*core.QuorumResult, header *core.BlobHeader) bool { diff --git a/disperser/batcher/batcher_test.go b/disperser/batcher/batcher_test.go index 97ddbbe578..dafe82b61a 100644 --- a/disperser/batcher/batcher_test.go +++ b/disperser/batcher/batcher_test.go @@ -18,8 +18,8 @@ import ( coremock "github.com/Layr-Labs/eigenda/core/mock" "github.com/Layr-Labs/eigenda/disperser" bat "github.com/Layr-Labs/eigenda/disperser/batcher" - "github.com/Layr-Labs/eigenda/disperser/batcher/mock" batchermock "github.com/Layr-Labs/eigenda/disperser/batcher/mock" + batmock "github.com/Layr-Labs/eigenda/disperser/batcher/mock" "github.com/Layr-Labs/eigenda/disperser/common/inmem" dmock "github.com/Layr-Labs/eigenda/disperser/mock" "github.com/Layr-Labs/eigenda/encoding" @@ -29,6 +29,7 @@ import ( gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) var ( @@ -43,6 +44,8 @@ type batcherComponents struct { encoderClient *disperser.LocalEncoderClient encodingStreamer *bat.EncodingStreamer ethClient *cmock.MockEthClient + dispatcher *dmock.Dispatcher + chainData *coremock.ChainDataMock } // makeTestEncoder makes an encoder currently using the only supported backend. @@ -72,15 +75,17 @@ func makeTestBlob(securityParams []*core.SecurityParam) core.Blob { func makeBatcher(t *testing.T) (*batcherComponents, *bat.Batcher, func() []time.Time) { // Common Components + // logger, err := common.NewLogger(common.DefaultLoggerConfig()) + // assert.NoError(t, err) logger := logging.NewNoopLogger() finalizationBlockDelay := uint(75) // Core Components cst, err := coremock.MakeChainDataMock(map[uint8]int{ - 0: 10, - 1: 10, - 2: 10, + 0: 4, + 1: 4, + 2: 6, }) assert.NoError(t, err) cst.On("GetCurrentBlockNumber").Return(uint(10)+finalizationBlockDelay, nil) @@ -121,7 +126,7 @@ func makeBatcher(t *testing.T) (*batcherComponents, *bat.Batcher, func() []time. encoderClient := disperser.NewLocalEncoderClient(p) finalizer := batchermock.NewFinalizer() ethClient := &cmock.MockEthClient{} - txnManager := mock.NewTxnManager() + txnManager := batmock.NewTxnManager() b, err := bat.NewBatcher(config, timeoutConfig, blobStore, dispatcher, cst, asgn, encoderClient, agg, ethClient, finalizer, transactor, txnManager, logger, metrics, handleBatchLivenessChan) assert.NoError(t, err) @@ -151,6 +156,8 @@ func makeBatcher(t *testing.T) (*batcherComponents, *bat.Batcher, func() []time. encoderClient: encoderClient, encodingStreamer: b.EncodingStreamer, ethClient: ethClient, + dispatcher: dispatcher, + chainData: cst, }, b, func() []time.Time { close(doneListening) // Stop the goroutine listening to heartbeats @@ -180,6 +187,7 @@ func TestBatcherIterations(t *testing.T) { ConfirmationThreshold: 100, }}) components, batcher, getHeartbeats := makeBatcher(t) + components.dispatcher.On("DisperseBatch").Return(map[core.OperatorID]struct{}{}) defer func() { heartbeats := getHeartbeats() @@ -222,7 +230,28 @@ func TestBatcherIterations(t *testing.T) { assert.Equal(t, uint64(24576), size) // Robert checks it txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) - components.transactor.On("BuildConfirmBatchTxn").Return(txn, nil) + components.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + quorumResults := args[2].(map[core.QuorumID]*core.QuorumResult) + assert.Len(t, quorumResults, 2) + assert.Contains(t, quorumResults, core.QuorumID(0)) + assert.Contains(t, quorumResults, core.QuorumID(1)) + + aggSig := args[3].(*core.SignatureAggregation) + assert.Empty(t, aggSig.NonSigners) + assert.Len(t, aggSig.QuorumAggPubKeys, 2) + assert.Contains(t, aggSig.QuorumAggPubKeys, core.QuorumID(0)) + assert.Contains(t, aggSig.QuorumAggPubKeys, core.QuorumID(1)) + assert.Equal(t, aggSig.QuorumResults, map[core.QuorumID]*core.QuorumResult{ + core.QuorumID(0): { + QuorumID: core.QuorumID(0), + PercentSigned: uint8(100), + }, + core.QuorumID(1): { + QuorumID: core.QuorumID(1), + PercentSigned: uint8(100), + }, + }) + }).Return(txn, nil) components.txnManager.On("ProcessTransaction").Return(nil) err = batcher.HandleSingleBatch(ctx) @@ -278,6 +307,7 @@ func TestBlobFailures(t *testing.T) { }}) components, batcher, getHeartbeats := makeBatcher(t) + components.dispatcher.On("DisperseBatch").Return(map[core.OperatorID]struct{}{}) defer func() { heartbeats := getHeartbeats() @@ -297,7 +327,7 @@ func TestBlobFailures(t *testing.T) { assert.NoError(t, err) txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) - components.transactor.On("BuildConfirmBatchTxn").Return(txn, nil) + components.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(txn, nil) components.txnManager.On("ProcessTransaction").Return(nil) // Test with receipt response with error @@ -384,6 +414,7 @@ func TestBlobRetry(t *testing.T) { }}) components, batcher, getHeartbeats := makeBatcher(t) + components.dispatcher.On("DisperseBatch").Return(map[core.OperatorID]struct{}{}) defer func() { heartbeats := getHeartbeats() @@ -406,7 +437,7 @@ func TestBlobRetry(t *testing.T) { assert.NotNil(t, encodedResult) txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) - components.transactor.On("BuildConfirmBatchTxn").Return(txn, nil) + components.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(txn, nil) components.txnManager.On("ProcessTransaction").Return(nil) err = batcher.HandleSingleBatch(ctx) @@ -490,6 +521,7 @@ func TestRetryTxnReceipt(t *testing.T) { ConfirmationThreshold: 100, }}) components, batcher, getHeartbeats := makeBatcher(t) + components.dispatcher.On("DisperseBatch").Return(map[core.OperatorID]struct{}{}) defer func() { heartbeats := getHeartbeats() @@ -534,7 +566,7 @@ func TestRetryTxnReceipt(t *testing.T) { assert.NoError(t, err) txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) - components.transactor.On("BuildConfirmBatchTxn").Return(txn, nil) + components.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(txn, nil) components.txnManager.On("ProcessTransaction").Return(nil) err = batcher.HandleSingleBatch(ctx) @@ -554,3 +586,168 @@ func TestRetryTxnReceipt(t *testing.T) { assert.Equal(t, meta.ConfirmationInfo.BatchID, uint32(3)) components.ethClient.AssertNumberOfCalls(t, "TransactionReceipt", 3) } + +// TestBlobAttestationFailures tests a case where the attestation fails for all blobs in one quorum, +// in which case the quorum should be omitted from the confirmation transaction. +func TestBlobAttestationFailures(t *testing.T) { + blob0 := makeTestBlob([]*core.SecurityParam{ + { + QuorumID: 0, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 1, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + }) + + blob1 := makeTestBlob([]*core.SecurityParam{ + { + QuorumID: 0, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 1, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 2, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + }) + + components, batcher, _ := makeBatcher(t) + + blobStore := components.blobStore + ctx := context.Background() + _, _ = queueBlob(t, ctx, &blob0, blobStore) + _, _ = queueBlob(t, ctx, &blob1, blobStore) + + // Start the batcher + out := make(chan bat.EncodingResultOrStatus) + err := components.encodingStreamer.RequestEncoding(ctx, out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + + components.dispatcher.On("DisperseBatch").Return(map[core.OperatorID]struct{}{ + // operator 5 is only in quorum 2 + coremock.MakeOperatorId(5): {}, + }) + + txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) + components.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + quorumResults := args[2].(map[core.QuorumID]*core.QuorumResult) + assert.Len(t, quorumResults, 2) + assert.Contains(t, quorumResults, core.QuorumID(0)) + assert.Contains(t, quorumResults, core.QuorumID(1)) + // should not contain quorum 2 + assert.NotContains(t, quorumResults, core.QuorumID(2)) + + aggSig := args[3].(*core.SignatureAggregation) + assert.Empty(t, aggSig.NonSigners) + assert.NotContains(t, aggSig.QuorumAggPubKeys, core.QuorumID(2)) + assert.NotContains(t, aggSig.QuorumResults, core.QuorumID(2)) + }).Return(txn, nil) + components.txnManager.On("ProcessTransaction").Return(nil) + + // Test with receipt response with error + err = batcher.HandleSingleBatch(ctx) + assert.NoError(t, err) +} + +// TestBlobAttestationFailures2 tests a case where the attestation fails for some blobs in one quorum, +// in which case the quorum should not be omitted from the confirmation transaction. +func TestBlobAttestationFailures2(t *testing.T) { + blob0 := makeTestBlob([]*core.SecurityParam{ + { + QuorumID: 0, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 2, + AdversaryThreshold: 80, + ConfirmationThreshold: 50, + }, + }) + + blob1 := makeTestBlob([]*core.SecurityParam{ + { + QuorumID: 0, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 2, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + }) + + components, batcher, _ := makeBatcher(t) + + blobStore := components.blobStore + ctx := context.Background() + _, _ = queueBlob(t, ctx, &blob0, blobStore) + _, _ = queueBlob(t, ctx, &blob1, blobStore) + + // Start the batcher + out := make(chan bat.EncodingResultOrStatus) + err := components.encodingStreamer.RequestEncoding(ctx, out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + err = components.encodingStreamer.ProcessEncodedBlobs(ctx, <-out) + assert.NoError(t, err) + + components.dispatcher.On("DisperseBatch").Return(map[core.OperatorID]struct{}{ + // this operator is only in quorum 2 + coremock.MakeOperatorId(5): {}, + }) + + txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) + components.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + quorumResults := args[2].(map[core.QuorumID]*core.QuorumResult) + assert.Len(t, quorumResults, 2) + assert.Contains(t, quorumResults, core.QuorumID(0)) + assert.Contains(t, quorumResults, core.QuorumID(2)) + + aggSig := args[3].(*core.SignatureAggregation) + assert.Len(t, aggSig.NonSigners, 1) + assert.Contains(t, aggSig.QuorumAggPubKeys, core.QuorumID(0)) + assert.Contains(t, aggSig.QuorumAggPubKeys, core.QuorumID(2)) + assert.Equal(t, aggSig.QuorumResults, map[core.QuorumID]*core.QuorumResult{ + core.QuorumID(0): { + QuorumID: core.QuorumID(0), + PercentSigned: uint8(100), + }, + core.QuorumID(2): { + QuorumID: core.QuorumID(2), + PercentSigned: uint8(71), + }, + }) + }).Return(txn, nil) + components.txnManager.On("ProcessTransaction").Return(nil) + + // Test with receipt response with error + err = batcher.HandleSingleBatch(ctx) + assert.NoError(t, err) +} diff --git a/disperser/mock/dispatcher.go b/disperser/mock/dispatcher.go index 980bf0c9fd..59d13f686e 100644 --- a/disperser/mock/dispatcher.go +++ b/disperser/mock/dispatcher.go @@ -2,25 +2,33 @@ package mock import ( "context" + "errors" "github.com/Layr-Labs/eigenda/core" - "github.com/Layr-Labs/eigenda/core/mock" + coremock "github.com/Layr-Labs/eigenda/core/mock" "github.com/Layr-Labs/eigenda/disperser" + "github.com/stretchr/testify/mock" ) type Dispatcher struct { - state *mock.PrivateOperatorState + mock.Mock + state *coremock.PrivateOperatorState } var _ disperser.Dispatcher = (*Dispatcher)(nil) -func NewDispatcher(state *mock.PrivateOperatorState) disperser.Dispatcher { +func NewDispatcher(state *coremock.PrivateOperatorState) *Dispatcher { return &Dispatcher{ state: state, } } func (d *Dispatcher) DisperseBatch(ctx context.Context, state *core.IndexedOperatorState, blobs []core.EncodedBlob, header *core.BatchHeader) chan core.SigningMessage { + args := d.Called() + var nonSigners map[core.OperatorID]struct{} + if args.Get(0) != nil { + nonSigners = args.Get(0).(map[core.OperatorID]struct{}) + } update := make(chan core.SigningMessage) message, err := header.GetBatchHeaderHash() if err != nil { @@ -34,13 +42,22 @@ func (d *Dispatcher) DisperseBatch(ctx context.Context, state *core.IndexedOpera } go func() { - for id, op := range d.state.PrivateOperators { - sig := op.KeyPair.SignMessage(message) + for id := range state.IndexedOperators { + info := d.state.PrivateOperators[id] + if _, ok := nonSigners[id]; ok { + update <- core.SigningMessage{ + Signature: nil, + Operator: id, + Err: errors.New("not a signer"), + } + } else { + sig := info.KeyPair.SignMessage(message) - update <- core.SigningMessage{ - Signature: sig, - Operator: id, - Err: nil, + update <- core.SigningMessage{ + Signature: sig, + Operator: id, + Err: nil, + } } } }() diff --git a/test/integration_test.go b/test/integration_test.go index c75afeee0b..a1fc47c04b 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -53,6 +53,7 @@ import ( "github.com/ethereum/go-ethereum/core/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" ) var ( @@ -402,7 +403,7 @@ func TestDispersalAndRetrieval(t *testing.T) { dis.batcher.EncodingStreamer.Pool.StopWait() txn := types.NewTransaction(0, gethcommon.Address{}, big.NewInt(0), 0, big.NewInt(0), nil) - dis.transactor.On("BuildConfirmBatchTxn").Return(txn, nil) + dis.transactor.On("BuildConfirmBatchTxn", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(txn, nil) dis.txnManager.On("ProcessTransaction").Return(nil) err = dis.batcher.HandleSingleBatch(ctx)