diff --git a/disperser/batcher/batcher.go b/disperser/batcher/batcher.go index b54b87a424..e55ce932f7 100644 --- a/disperser/batcher/batcher.go +++ b/disperser/batcher/batcher.go @@ -146,8 +146,26 @@ func NewBatcher( }, nil } +func (b *Batcher) RecoverState(ctx context.Context) error { + metas, err := b.Queue.GetBlobMetadataByStatus(ctx, disperser.Dispersing) + if err != nil { + return fmt.Errorf("failed to get blobs in dispersing state: %w", err) + } + for _, meta := range metas { + err = b.Queue.MarkBlobProcessing(ctx, meta.GetBlobKey()) + if err != nil { + return fmt.Errorf("failed to mark blob (%s) as processing: %w", meta.GetBlobKey(), err) + } + } + return nil +} + func (b *Batcher) Start(ctx context.Context) error { - err := b.ChainState.Start(ctx) + err := b.RecoverState(ctx) + if err != nil { + return fmt.Errorf("failed to recover state: %w", err) + } + err = b.ChainState.Start(ctx) if err != nil { return err } diff --git a/disperser/batcher/batcher_test.go b/disperser/batcher/batcher_test.go index dafe82b61a..29786e9205 100644 --- a/disperser/batcher/batcher_test.go +++ b/disperser/batcher/batcher_test.go @@ -751,3 +751,59 @@ func TestBlobAttestationFailures2(t *testing.T) { err = batcher.HandleSingleBatch(ctx) assert.NoError(t, err) } + +func TestBatcherRecoverState(t *testing.T) { + blob0 := makeTestBlob([]*core.SecurityParam{ + { + QuorumID: 0, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 2, + AdversaryThreshold: 80, + ConfirmationThreshold: 50, + }, + }) + + blob1 := makeTestBlob([]*core.SecurityParam{ + { + QuorumID: 0, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + { + QuorumID: 2, + AdversaryThreshold: 80, + ConfirmationThreshold: 100, + }, + }) + + components, batcher, _ := makeBatcher(t) + + blobStore := components.blobStore + ctx := context.Background() + _, key1 := queueBlob(t, ctx, &blob0, blobStore) + _, _ = queueBlob(t, ctx, &blob1, blobStore) + + err := blobStore.MarkBlobDispersing(ctx, key1) + assert.NoError(t, err) + processingBlobs, err := blobStore.GetBlobMetadataByStatus(ctx, disperser.Processing) + assert.NoError(t, err) + assert.Len(t, processingBlobs, 1) + + dispersingBlobs, err := blobStore.GetBlobMetadataByStatus(ctx, disperser.Dispersing) + assert.NoError(t, err) + assert.Len(t, dispersingBlobs, 1) + + err = batcher.RecoverState(context.Background()) + assert.NoError(t, err) + + processingBlobs, err = blobStore.GetBlobMetadataByStatus(ctx, disperser.Processing) + assert.NoError(t, err) + assert.Len(t, processingBlobs, 2) + + dispersingBlobs, err = blobStore.GetBlobMetadataByStatus(ctx, disperser.Dispersing) + assert.NoError(t, err) + assert.Len(t, dispersingBlobs, 0) +}