diff --git a/blob/service_test.go b/blob/service_test.go index 2cbfcf7b03..8c55d8824f 100644 --- a/blob/service_test.go +++ b/blob/service_test.go @@ -145,9 +145,11 @@ func TestBlobService_Get(t *testing.T) { shareOffset := 0 for i := range blobs { row, col := calculateIndex(len(h.DAH.RowRoots), blobs[i].index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx := shwap.SampleCoords{Row: row, Col: col} require.NoError(t, err) - require.True(t, bytes.Equal(sh.ToBytes(), resultShares[shareOffset].ToBytes()), + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx}) + require.NoError(t, err) + require.True(t, bytes.Equal(smpls[0].Share.ToBytes(), resultShares[shareOffset].ToBytes()), fmt.Sprintf("issue on %d attempt. ROW:%d, COL: %d, blobIndex:%d", i, row, col, blobs[i].index), ) shareOffset += libshare.SparseSharesNeeded(uint32(len(blobs[i].Data()))) @@ -487,10 +489,13 @@ func TestService_GetSingleBlobWithoutPadding(t *testing.T) { h, err := service.headerGetter(ctx, 1) require.NoError(t, err) row, col := calculateIndex(len(h.DAH.RowRoots), newBlob.index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx := shwap.SampleCoords{Row: row, Col: col} + require.NoError(t, err) + + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx}) require.NoError(t, err) - assert.Equal(t, sh, resultShares[0]) + assert.Equal(t, smpls[0].Share, resultShares[0]) } func TestService_Get(t *testing.T) { @@ -521,10 +526,13 @@ func TestService_Get(t *testing.T) { assert.Equal(t, b.Commitment, blob.Commitment) row, col := calculateIndex(len(h.DAH.RowRoots), b.index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx := shwap.SampleCoords{Row: row, Col: col} require.NoError(t, err) - assert.Equal(t, sh, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i)) + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx}) + require.NoError(t, err) + + assert.Equal(t, smpls[0].Share, resultShares[shareOffset], fmt.Sprintf("issue on %d attempt", i)) shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data()))) } } @@ -580,10 +588,13 @@ func TestService_GetAllWithoutPadding(t *testing.T) { require.True(t, blobs[i].compareCommitments(blob.Commitment)) row, col := calculateIndex(len(h.DAH.RowRoots), blob.index) - sh, err := service.shareGetter.GetShare(ctx, h, row, col) + idx := shwap.SampleCoords{Row: row, Col: col} + require.NoError(t, err) + + smpls, err := service.shareGetter.GetSamples(ctx, h, []shwap.SampleCoords{idx}) require.NoError(t, err) - assert.Equal(t, sh, resultShares[shareOffset]) + assert.Equal(t, smpls[0].Share, resultShares[shareOffset]) shareOffset += libshare.SparseSharesNeeded(uint32(len(blob.Data()))) } } @@ -902,10 +913,12 @@ func createService(ctx context.Context, t testing.TB, shares []libshare.Share) * nd, err := eds.NamespaceData(ctx, accessor, ns) return nd, err }) - shareGetter.EXPECT().GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes(). - DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, row, col int) (libshare.Share, error) { - s, err := accessor.Sample(ctx, row, col) - return s.Share, err + shareGetter.EXPECT().GetSamples(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes(). + DoAndReturn(func(ctx context.Context, h *header.ExtendedHeader, + indices []shwap.SampleCoords, + ) ([]shwap.Sample, error) { + smpl, err := accessor.Sample(ctx, indices[0]) + return []shwap.Sample{smpl}, err }) // create header and put it into the store diff --git a/nodebuilder/share/mocks/api.go b/nodebuilder/share/mocks/api.go index cccc81a452..7fde2338cc 100644 --- a/nodebuilder/share/mocks/api.go +++ b/nodebuilder/share/mocks/api.go @@ -8,6 +8,7 @@ import ( context "context" reflect "reflect" + header "github.com/celestiaorg/celestia-node/header" share "github.com/celestiaorg/celestia-node/nodebuilder/share" shwap "github.com/celestiaorg/celestia-node/share/shwap" share0 "github.com/celestiaorg/go-square/v2/share" @@ -53,6 +54,21 @@ func (mr *MockModuleMockRecorder) GetEDS(arg0, arg1 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEDS", reflect.TypeOf((*MockModule)(nil).GetEDS), arg0, arg1) } +// GetNamespaceData mocks base method. +func (m *MockModule) GetNamespaceData(arg0 context.Context, arg1 uint64, arg2 share0.Namespace) (shwap.NamespaceData, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetNamespaceData", arg0, arg1, arg2) + ret0, _ := ret[0].(shwap.NamespaceData) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetNamespaceData indicates an expected call of GetNamespaceData. +func (mr *MockModuleMockRecorder) GetNamespaceData(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNamespaceData", reflect.TypeOf((*MockModule)(nil).GetNamespaceData), arg0, arg1, arg2) +} + // GetRange mocks base method. func (m *MockModule) GetRange(arg0 context.Context, arg1 uint64, arg2, arg3 int) (*share.GetRangeResult, error) { m.ctrl.T.Helper() @@ -68,34 +84,34 @@ func (mr *MockModuleMockRecorder) GetRange(arg0, arg1, arg2, arg3 interface{}) * return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRange", reflect.TypeOf((*MockModule)(nil).GetRange), arg0, arg1, arg2, arg3) } -// GetShare mocks base method. -func (m *MockModule) GetShare(arg0 context.Context, arg1 uint64, arg2, arg3 int) (share0.Share, error) { +// GetSamples mocks base method. +func (m *MockModule) GetSamples(arg0 context.Context, arg1 *header.ExtendedHeader, arg2 []shwap.SampleCoords) ([]shwap.Sample, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetShare", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(share0.Share) + ret := m.ctrl.Call(m, "GetSamples", arg0, arg1, arg2) + ret0, _ := ret[0].([]shwap.Sample) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetShare indicates an expected call of GetShare. -func (mr *MockModuleMockRecorder) GetShare(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// GetSamples indicates an expected call of GetSamples. +func (mr *MockModuleMockRecorder) GetSamples(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShare", reflect.TypeOf((*MockModule)(nil).GetShare), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSamples", reflect.TypeOf((*MockModule)(nil).GetSamples), arg0, arg1, arg2) } -// GetSharesByNamespace mocks base method. -func (m *MockModule) GetSharesByNamespace(arg0 context.Context, arg1 uint64, arg2 share0.Namespace) (shwap.NamespaceData, error) { +// GetShare mocks base method. +func (m *MockModule) GetShare(arg0 context.Context, arg1 uint64, arg2, arg3 int) (share0.Share, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetSharesByNamespace", arg0, arg1, arg2) - ret0, _ := ret[0].(shwap.NamespaceData) + ret := m.ctrl.Call(m, "GetShare", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(share0.Share) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetSharesByNamespace indicates an expected call of GetSharesByNamespace. -func (mr *MockModuleMockRecorder) GetSharesByNamespace(arg0, arg1, arg2 interface{}) *gomock.Call { +// GetShare indicates an expected call of GetShare. +func (mr *MockModuleMockRecorder) GetShare(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSharesByNamespace", reflect.TypeOf((*MockModule)(nil).GetSharesByNamespace), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShare", reflect.TypeOf((*MockModule)(nil).GetShare), arg0, arg1, arg2, arg3) } // SharesAvailable mocks base method. diff --git a/nodebuilder/share/share.go b/nodebuilder/share/share.go index 8a3efcc757..783018fe47 100644 --- a/nodebuilder/share/share.go +++ b/nodebuilder/share/share.go @@ -8,6 +8,7 @@ import ( libshare "github.com/celestiaorg/go-square/v2/share" "github.com/celestiaorg/rsmt2d" + "github.com/celestiaorg/celestia-node/header" headerServ "github.com/celestiaorg/celestia-node/nodebuilder/header" "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds" @@ -45,6 +46,8 @@ type Module interface { SharesAvailable(ctx context.Context, height uint64) error // GetShare gets a Share by coordinates in EDS. GetShare(ctx context.Context, height uint64, row, col int) (libshare.Share, error) + // GetSamples gets sample for given indices. + GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []shwap.SampleCoords) ([]shwap.Sample, error) // GetEDS gets the full EDS identified by the given extended header. GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) // GetNamespaceData gets all shares from an EDS within the given namespace. @@ -65,6 +68,11 @@ type API struct { height uint64, row, col int, ) (libshare.Share, error) `perm:"read"` + GetSamples func( + ctx context.Context, + header *header.ExtendedHeader, + indices []shwap.SampleCoords, + ) ([]shwap.Sample, error) `perm:"read"` GetEDS func( ctx context.Context, height uint64, @@ -90,6 +98,12 @@ func (api *API) GetShare(ctx context.Context, height uint64, row, col int) (libs return api.Internal.GetShare(ctx, height, row, col) } +func (api *API) GetSamples(ctx context.Context, header *header.ExtendedHeader, + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { + return api.Internal.GetSamples(ctx, header, indices) +} + func (api *API) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) { return api.Internal.GetEDS(ctx, height) } @@ -117,7 +131,21 @@ func (m module) GetShare(ctx context.Context, height uint64, row, col int) (libs if err != nil { return libshare.Share{}, err } - return m.getter.GetShare(ctx, header, row, col) + + idx := shwap.SampleCoords{Row: row, Col: col} + + smpls, err := m.getter.GetSamples(ctx, header, []shwap.SampleCoords{idx}) + if err != nil { + return libshare.Share{}, err + } + + return smpls[0].Share, nil +} + +func (m module) GetSamples(ctx context.Context, header *header.ExtendedHeader, + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { + return m.getter.GetSamples(ctx, header, indices) } func (m module) GetEDS(ctx context.Context, height uint64) (*rsmt2d.ExtendedDataSquare, error) { diff --git a/share/availability/light/availability.go b/share/availability/light/availability.go index 6a23abe3d9..b33825fe4b 100644 --- a/share/availability/light/availability.go +++ b/share/availability/light/availability.go @@ -114,13 +114,12 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header return nil } - var ( - mutex sync.Mutex - failedSamples []Sample - wg sync.WaitGroup - ) + log.Debugw("starting sampling session", "root", dah.String()) - log.Debugw("starting sampling session", "height", header.Height()) + idxs := make([]shwap.SampleCoords, len(samples.Remaining)) + for i, s := range samples.Remaining { + idxs[i] = shwap.SampleCoords{Row: s.Row, Col: s.Col} + } // remove one second from the deadline to ensure we have enough time to process the results samplingCtx, cancel := context.WithCancel(ctx) @@ -129,25 +128,21 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header } defer cancel() - // Concurrently sample shares - for _, s := range samples.Remaining { - wg.Add(1) - go func(s Sample) { - defer wg.Done() - _, err := la.getter.GetShare(samplingCtx, header, s.Row, s.Col) - mutex.Lock() - defer mutex.Unlock() - if err != nil { - log.Debugw("error fetching share", "height", header.Height(), "row", s.Row, "col", s.Col) - failedSamples = append(failedSamples, s) - } else { - samples.Available = append(samples.Available, s) - } - }(s) + smpls, errGetSamples := la.getter.GetSamples(samplingCtx, header, idxs) + if len(smpls) == 0 { + return share.ErrNotAvailable + } + + var failedSamples []shwap.SampleCoords + + for i, smpl := range smpls { + if smpl.IsEmpty() { + failedSamples = append(failedSamples, shwap.SampleCoords{Row: idxs[i].Row, Col: idxs[i].Col}) + } else { + samples.Available = append(samples.Available, shwap.SampleCoords{Row: idxs[i].Row, Col: idxs[i].Col}) + } } - wg.Wait() - // Update remaining samples with failed ones samples.Remaining = failedSamples // Store the updated sampling result @@ -162,16 +157,17 @@ func (la *ShareAvailability) SharesAvailable(ctx context.Context, header *header return fmt.Errorf("store sampling result: %w", err) } - if errors.Is(ctx.Err(), context.Canceled) { + if errors.Is(errGetSamples, context.Canceled) { // Availability did not complete due to context cancellation, return context error instead of // share.ErrNotAvailable - return ctx.Err() + return context.Canceled } // if any of the samples failed, return an error if len(failedSamples) > 0 { return share.ErrNotAvailable } + return nil } @@ -210,7 +206,9 @@ func (la *ShareAvailability) Prune(ctx context.Context, h *header.ExtendedHeader // delete stored samples for _, sample := range result.Available { - blk, err := bitswap.NewEmptySampleBlock(h.Height(), sample.Row, sample.Col, len(h.DAH.RowRoots)) + idx := shwap.SampleCoords{Row: sample.Row, Col: sample.Col} + + blk, err := bitswap.NewEmptySampleBlock(h.Height(), idx, len(h.DAH.RowRoots)) if err != nil { return fmt.Errorf("marshal sample ID: %w", err) } diff --git a/share/availability/light/availability_test.go b/share/availability/light/availability_test.go index 64834d46b5..b1a9299bea 100644 --- a/share/availability/light/availability_test.go +++ b/share/availability/light/availability_test.go @@ -4,56 +4,65 @@ import ( "context" _ "embed" "encoding/json" - "errors" + "maps" + "slices" "sync" "sync/atomic" "testing" "time" "github.com/golang/mock/gomock" + "github.com/ipfs/boxo/bitswap/client" "github.com/ipfs/boxo/blockstore" "github.com/ipfs/boxo/exchange" - "github.com/ipfs/boxo/exchange/offline" blocks "github.com/ipfs/go-block-format" "github.com/ipfs/go-cid" "github.com/ipfs/go-datastore" ds_sync "github.com/ipfs/go-datastore/sync" + "github.com/libp2p/go-libp2p/core/host" + mocknet "github.com/libp2p/go-libp2p/p2p/net/mock" "github.com/stretchr/testify/require" libshare "github.com/celestiaorg/go-square/v2/share" + "github.com/celestiaorg/nmt" "github.com/celestiaorg/rsmt2d" "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/header/headertest" "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds" "github.com/celestiaorg/celestia-node/share/eds/edstest" "github.com/celestiaorg/celestia-node/share/shwap" "github.com/celestiaorg/celestia-node/share/shwap/getters/mock" "github.com/celestiaorg/celestia-node/share/shwap/p2p/bitswap" "github.com/celestiaorg/celestia-node/share/shwap/p2p/shrex" - "github.com/celestiaorg/celestia-node/store" ) func TestSharesAvailableSuccess(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - eds := edstest.RandEDS(t, 16) - roots, err := share.NewAxisRoots(eds) + square := edstest.RandEDS(t, 16) + roots, err := share.NewAxisRoots(square) require.NoError(t, err) eh := headertest.RandExtendedHeaderWithRoot(t, roots) getter := mock.NewMockGetter(gomock.NewController(t)) getter.EXPECT(). - GetShare(gomock.Any(), eh, gomock.Any(), gomock.Any()). + GetSamples(gomock.Any(), eh, gomock.Any()). DoAndReturn( - func(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) { - rawSh := eds.GetCell(uint(row), uint(col)) - sh, err := libshare.NewShare(rawSh) - if err != nil { - return libshare.Share{}, err + func(_ context.Context, hdr *header.ExtendedHeader, indices []shwap.SampleCoords) ([]shwap.Sample, error) { + acc := eds.Rsmt2D{ExtendedDataSquare: square} + smpls := make([]shwap.Sample, len(indices)) + for i, idx := range indices { + smpl, err := acc.Sample(ctx, idx) + if err != nil { + return nil, err + } + + smpls[i] = smpl } - return *sh, nil + return smpls, nil }). AnyTimes() @@ -87,8 +96,8 @@ func TestSharesAvailableSkipSampled(t *testing.T) { // Create a getter that always returns ErrNotFound getter := mock.NewMockGetter(gomock.NewController(t)) getter.EXPECT(). - GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(libshare.Share{}, shrex.ErrNotFound). + GetSamples(gomock.Any(), gomock.Any(), gomock.Any()). + Return(nil, shrex.ErrNotFound). AnyTimes() ds := datastore.NewMapDatastore() @@ -104,8 +113,8 @@ func TestSharesAvailableSkipSampled(t *testing.T) { // Store a successful sampling result in the datastore samplingResult := &SamplingResult{ - Available: make([]Sample, avail.params.SampleAmount), - Remaining: []Sample{}, + Available: make([]shwap.SampleCoords, avail.params.SampleAmount), + Remaining: []shwap.SampleCoords{}, } data, err := json.Marshal(samplingResult) require.NoError(t, err) @@ -135,16 +144,6 @@ func TestSharesAvailableFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - failGetter := mock.NewMockGetter(gomock.NewController(t)) - // Getter doesn't have the eds, so it should fail for all samples - failGetter.EXPECT(). - GetShare(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). - Return(libshare.Share{}, shrex.ErrNotFound). - AnyTimes() - - ds := datastore.NewMapDatastore() - avail := NewShareAvailability(failGetter, ds, nil) - type test struct { eh *header.ExtendedHeader roots *share.AxisRoots @@ -164,7 +163,16 @@ func TestSharesAvailableFailed(t *testing.T) { } for _, tt := range tests { - avail.getter = failGetter + failGetter := mock.NewMockGetter(gomock.NewController(t)) + ds := datastore.NewMapDatastore() + avail := NewShareAvailability(failGetter, ds, nil) + + // Getter doesn't have the eds, so it should fail for all samples + mockSamples := min(int(avail.params.SampleAmount), 2*len(tt.eh.DAH.RowRoots)) + failGetter.EXPECT(). + GetSamples(gomock.Any(), gomock.Any(), gomock.Any()). + Return(make([]shwap.Sample, mockSamples), shrex.ErrNotFound). + AnyTimes() err := avail.SharesAvailable(ctx, tt.eh) require.ErrorIs(t, err, share.ErrNotAvailable) @@ -185,7 +193,7 @@ func TestSharesAvailableFailed(t *testing.T) { } // Simulate a getter that now returns shares successfully - successfulGetter := newOnceGetter() + successfulGetter := newSuccessGetter() avail.getter = successfulGetter // should be able to retrieve all the failed samples now @@ -208,8 +216,8 @@ func TestSharesAvailableFailed(t *testing.T) { } // onceGetter should have no more samples stored after the call - successfulGetter.checkOnce(t) require.ElementsMatch(t, failed.Remaining, successfulGetter.sampledList()) + successfulGetter.checkOnce(t) } } @@ -219,7 +227,7 @@ func TestParallelAvailability(t *testing.T) { ds := datastore.NewMapDatastore() // Simulate a getter that returns shares successfully - successfulGetter := newOnceGetter() + successfulGetter := newSuccessGetter() avail := NewShareAvailability(successfulGetter, ds, nil) // create new eds, that is not available by getter @@ -229,8 +237,9 @@ func TestParallelAvailability(t *testing.T) { eh := headertest.RandExtendedHeaderWithRoot(t, roots) var wg sync.WaitGroup - for i := 0; i < 100; i++ { - wg.Add(1) + const iters = 100 + wg.Add(iters) + for i := 0; i < iters; i++ { go func() { defer wg.Done() err := avail.SharesAvailable(ctx, eh) @@ -239,6 +248,7 @@ func TestParallelAvailability(t *testing.T) { } wg.Wait() require.Len(t, successfulGetter.sampledList(), int(avail.params.SampleAmount)) + successfulGetter.checkOnce(t) // Verify that the sampling result is stored with all samples marked as available resultData, err := avail.ds.Get(ctx, datastoreKeyForRoot(roots)) @@ -252,19 +262,25 @@ func TestParallelAvailability(t *testing.T) { require.Len(t, samplingResult.Available, int(avail.params.SampleAmount)) } -type onceGetter struct { +type successGetter struct { *sync.Mutex - sampled map[Sample]int + sampled map[shwap.SampleCoords]int } -func newOnceGetter() onceGetter { - return onceGetter{ +func newSuccessGetter() successGetter { + return successGetter{ Mutex: &sync.Mutex{}, - sampled: make(map[Sample]int), + sampled: make(map[shwap.SampleCoords]int), } } -func (g onceGetter) checkOnce(t *testing.T) { +func (g successGetter) sampledList() []shwap.SampleCoords { + g.Lock() + defer g.Unlock() + return slices.Collect(maps.Keys(g.sampled)) +} + +func (g successGetter) checkOnce(t *testing.T) { g.Lock() defer g.Unlock() for s, count := range g.sampled { @@ -274,29 +290,26 @@ func (g onceGetter) checkOnce(t *testing.T) { } } -func (g onceGetter) sampledList() []Sample { +func (g successGetter) GetSamples(_ context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { g.Lock() defer g.Unlock() - samples := make([]Sample, 0, len(g.sampled)) - for s := range g.sampled { - samples = append(samples, s) - } - return samples -} -func (g onceGetter) GetShare(_ context.Context, _ *header.ExtendedHeader, row, col int) (libshare.Share, error) { - g.Lock() - defer g.Unlock() - s := Sample{Row: row, Col: col} - g.sampled[s]++ - return libshare.Share{}, nil + smpls := make([]shwap.Sample, 0, len(indices)) + for _, idx := range indices { + s := shwap.SampleCoords{Row: idx.Row, Col: idx.Col} + g.sampled[s]++ + smpls = append(smpls, shwap.Sample{Proof: &nmt.Proof{}}) + } + return smpls, nil } -func (g onceGetter) GetEDS(_ context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) { +func (g successGetter) GetEDS(_ context.Context, _ *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) { panic("not implemented") } -func (g onceGetter) GetNamespaceData( +func (g successGetter) GetNamespaceData( _ context.Context, _ *header.ExtendedHeader, _ libshare.Namespace, @@ -309,19 +322,11 @@ func TestPruneAll(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) t.Cleanup(cancel) - dir := t.TempDir() - store, err := store.NewStore(store.DefaultParameters(), dir) - require.NoError(t, err) - defer require.NoError(t, store.Stop(ctx)) eds, h := randEdsAndHeader(t, size) - err = store.PutODSQ4(ctx, h.DAH, h.Height(), eds) - require.NoError(t, err) - - // Create a new bitswap getter ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) clientBs := blockstore.NewBlockstore(ds) - serverBS := &bitswap.Blockstore{Getter: store} - ex := newFakeExchange(serverBS) + + ex := newExchangeOverEDS(ctx, t, eds) getter := bitswap.NewGetter(ex, clientBs, 0) getter.Start() defer getter.Stop() @@ -329,7 +334,7 @@ func TestPruneAll(t *testing.T) { // Create a new ShareAvailability instance and sample the shares sampleAmount := uint(20) avail := NewShareAvailability(getter, ds, clientBs, WithSampleAmount(sampleAmount)) - err = avail.SharesAvailable(ctx, h) + err := avail.SharesAvailable(ctx, h) require.NoError(t, err) // close ShareAvailability to force flush of batched writes avail.Close(ctx) @@ -356,19 +361,11 @@ func TestPrunePartialFailed(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) t.Cleanup(cancel) - dir := t.TempDir() - store, err := store.NewStore(store.DefaultParameters(), dir) - require.NoError(t, err) - defer require.NoError(t, store.Stop(ctx)) eds, h := randEdsAndHeader(t, size) - err = store.PutODSQ4(ctx, h.DAH, h.Height(), eds) - require.NoError(t, err) - - // Create a new bitswap getter ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) clientBs := blockstore.NewBlockstore(ds) - serverBS := newHalfFailBlockstore(&bitswap.Blockstore{Getter: store}) - ex := newFakeExchange(serverBS) + + ex := newHalfSessionExchange(newExchangeOverEDS(ctx, t, eds)) getter := bitswap.NewGetter(ex, clientBs, 0) getter.Start() defer getter.Stop() @@ -376,8 +373,8 @@ func TestPrunePartialFailed(t *testing.T) { // Create a new ShareAvailability instance and sample the shares sampleAmount := uint(20) avail := NewShareAvailability(getter, ds, clientBs, WithSampleAmount(sampleAmount)) - err = avail.SharesAvailable(ctx, h) - require.NoError(t, err) + err := avail.SharesAvailable(ctx, h) + require.Error(t, err) // close ShareAvailability to force flush of batched writes avail.Close(ctx) @@ -399,38 +396,116 @@ func TestPrunePartialFailed(t *testing.T) { require.False(t, exist) } -var _ exchange.SessionExchange = (*fakeSessionExchange)(nil) +func TestPruneWithCancelledContext(t *testing.T) { + const size = 8 + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + t.Cleanup(cancel) -func newFakeExchange(bs blockstore.Blockstore) *fakeSessionExchange { - return &fakeSessionExchange{ - Interface: offline.Exchange(bs), - session: offline.Exchange(bs), - } + eds, h := randEdsAndHeader(t, size) + ds := ds_sync.MutexWrap(datastore.NewMapDatastore()) + clientBs := blockstore.NewBlockstore(ds) + + ex := newTimeoutExchange(newExchangeOverEDS(ctx, t, eds)) + getter := bitswap.NewGetter(ex, clientBs, 0) + getter.Start() + defer getter.Stop() + + // Create a new ShareAvailability instance and sample the shares + sampleAmount := uint(20) + avail := NewShareAvailability(getter, ds, clientBs, WithSampleAmount(sampleAmount)) + + ctx2, cancel2 := context.WithTimeout(ctx, 1500*time.Millisecond) + defer cancel2() + go func() { + // cancel context a bit later. + time.Sleep(100 * time.Millisecond) + cancel2() + }() + + err := avail.SharesAvailable(ctx2, h) + require.Error(t, err, context.Canceled) + // close ShareAvailability to force flush of batched writes + avail.Close(ctx) + + preDeleteCount := countKeys(ctx, t, clientBs) + require.EqualValues(t, sampleAmount, preDeleteCount) + + // prune the samples + err = avail.Prune(ctx, h) + require.NoError(t, err) + + // Check if samples are deleted + postDeleteCount := countKeys(ctx, t, clientBs) + require.Zero(t, postDeleteCount) + + // Check if sampling result is deleted + exist, err := avail.ds.Has(ctx, datastoreKeyForRoot(h.DAH)) + require.NoError(t, err) + require.False(t, exist) } -type fakeSessionExchange struct { - exchange.Interface - session exchange.Fetcher +type halfSessionExchange struct { + exchange.SessionExchange + attempt atomic.Int32 } -func (fe *fakeSessionExchange) NewSession(context.Context) exchange.Fetcher { - return fe.session +func newHalfSessionExchange(ex exchange.SessionExchange) *halfSessionExchange { + return &halfSessionExchange{SessionExchange: ex} } -type halfFailBlockstore struct { - blockstore.Blockstore - attempt atomic.Int32 +func (hse *halfSessionExchange) NewSession(context.Context) exchange.Fetcher { + return hse +} + +func (hse *halfSessionExchange) GetBlocks(ctx context.Context, cids []cid.Cid) (<-chan blocks.Block, error) { + out := make(chan blocks.Block, len(cids)) + defer close(out) + + for _, cid := range cids { + if hse.attempt.Add(1)%2 == 0 { + continue + } + + blk, err := hse.SessionExchange.GetBlock(ctx, cid) + if err != nil { + return nil, err + } + + out <- blk + } + + return out, nil +} + +type timeoutExchange struct { + exchange.SessionExchange } -func newHalfFailBlockstore(bs blockstore.Blockstore) *halfFailBlockstore { - return &halfFailBlockstore{Blockstore: bs} +func newTimeoutExchange(ex exchange.SessionExchange) *timeoutExchange { + return &timeoutExchange{SessionExchange: ex} } -func (hfb *halfFailBlockstore) Get(ctx context.Context, c cid.Cid) (blocks.Block, error) { - if hfb.attempt.Add(1)%2 == 0 { - return nil, errors.New("fail") +func (hse *timeoutExchange) NewSession(context.Context) exchange.Fetcher { + return hse +} + +func (hse *timeoutExchange) GetBlocks(ctx context.Context, cids []cid.Cid) (<-chan blocks.Block, error) { + out := make(chan blocks.Block, len(cids)) + defer close(out) + + for _, cid := range cids { + blk, err := hse.SessionExchange.GetBlock(ctx, cid) + if err != nil { + return nil, err + } + + out <- blk } - return hfb.Blockstore.Get(ctx, c) + + // sleep guarantees that we context will be canceled in a test. + time.Sleep(200 * time.Millisecond) + + return out, nil } func randEdsAndHeader(t *testing.T, size int) (*rsmt2d.ExtendedDataSquare, *header.ExtendedHeader) { @@ -457,3 +532,55 @@ func countKeys(ctx context.Context, t *testing.T, bs blockstore.Blockstore) int } return count } + +func newExchangeOverEDS(ctx context.Context, t *testing.T, rsmt2d *rsmt2d.ExtendedDataSquare) exchange.SessionExchange { + bstore := &bitswap.Blockstore{ + Getter: testAccessorGetter{ + AccessorStreamer: &eds.Rsmt2D{ExtendedDataSquare: rsmt2d}, + }, + } + return newExchange(ctx, t, bstore) +} + +func newExchange(ctx context.Context, t *testing.T, bstore blockstore.Blockstore) exchange.SessionExchange { + net, err := mocknet.FullMeshLinked(3) + require.NoError(t, err) + + newServer(ctx, net.Hosts()[0], bstore) + newServer(ctx, net.Hosts()[1], bstore) + + client := newClient(ctx, net.Hosts()[2], bstore) + + err = net.ConnectAllButSelf() + require.NoError(t, err) + return client +} + +func newServer(ctx context.Context, host host.Host, store blockstore.Blockstore) { + net := bitswap.NewNetwork(host, "test") + server := bitswap.NewServer( + ctx, + net, + store, + ) + net.Start(server) +} + +func newClient(ctx context.Context, host host.Host, store blockstore.Blockstore) *client.Client { + net := bitswap.NewNetwork(host, "test") + client := bitswap.NewClient(ctx, net, store) + net.Start(client) + return client +} + +type testAccessorGetter struct { + eds.AccessorStreamer +} + +func (t testAccessorGetter) GetByHeight(context.Context, uint64) (eds.AccessorStreamer, error) { + return t.AccessorStreamer, nil +} + +func (t testAccessorGetter) HasByHeight(context.Context, uint64) (bool, error) { + return true, nil +} diff --git a/share/availability/light/sample.go b/share/availability/light/sample.go index 6857ab5365..fc0b41d08a 100644 --- a/share/availability/light/sample.go +++ b/share/availability/light/sample.go @@ -2,21 +2,17 @@ package light import ( crand "crypto/rand" + "maps" "math/big" + "slices" - "golang.org/x/exp/maps" + "github.com/celestiaorg/celestia-node/share/shwap" ) // SamplingResult holds the available and remaining samples. type SamplingResult struct { - Available []Sample `json:"available"` - Remaining []Sample `json:"remaining"` -} - -// Sample represents a coordinate in a 2D data square. -type Sample struct { - Row int `json:"row"` - Col int `json:"col"` + Available []shwap.SampleCoords `json:"available"` + Remaining []shwap.SampleCoords `json:"remaining"` } // NewSamplingResult creates a new SamplingResult with randomly selected samples. @@ -33,25 +29,25 @@ func NewSamplingResult(squareSize, sampleCount int) *SamplingResult { } // selectRandomSamples randomly picks unique coordinates from a square of given size. -func selectRandomSamples(squareSize, sampleCount int) []Sample { +func selectRandomSamples(squareSize, sampleCount int) []shwap.SampleCoords { total := squareSize * squareSize if sampleCount > total { sampleCount = total } - samples := make(map[Sample]struct{}, sampleCount) + samples := make(map[shwap.SampleCoords]struct{}, sampleCount) for len(samples) < sampleCount { - s := Sample{ + s := shwap.SampleCoords{ Row: randInt(squareSize), Col: randInt(squareSize), } samples[s] = struct{}{} } - return maps.Keys(samples) + return slices.Collect(maps.Keys(samples)) } -func randInt(max int) int { - n, err := crand.Int(crand.Reader, big.NewInt(int64(max))) +func randInt(m int) int { + n, err := crand.Int(crand.Reader, big.NewInt(int64(m))) if err != nil { panic(err) // won't panic as rand.Reader is endless } diff --git a/share/eds/accessor.go b/share/eds/accessor.go index 07eb6db542..eb3fea6e41 100644 --- a/share/eds/accessor.go +++ b/share/eds/accessor.go @@ -25,7 +25,7 @@ type Accessor interface { // Sample returns share and corresponding proof for row and column indices. Implementation can // choose which axis to use for proof. Chosen axis for proof should be indicated in the returned // Sample. - Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) + Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) // AxisHalf returns half of shares axis of the given type and index. Side is determined by // implementation. Implementations should indicate the side in the returned AxisHalf. AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) diff --git a/share/eds/close_once.go b/share/eds/close_once.go index cc217710ce..579b4b1e20 100644 --- a/share/eds/close_once.go +++ b/share/eds/close_once.go @@ -57,11 +57,11 @@ func (c *closeOnce) AxisRoots(ctx context.Context) (*share.AxisRoots, error) { return c.f.AxisRoots(ctx) } -func (c *closeOnce) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (c *closeOnce) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) { if c.closed.Load() { return shwap.Sample{}, errAccessorClosed } - return c.f.Sample(ctx, rowIdx, colIdx) + return c.f.Sample(ctx, idx) } func (c *closeOnce) AxisHalf( diff --git a/share/eds/close_once_test.go b/share/eds/close_once_test.go index d515ac7bda..e34299f805 100644 --- a/share/eds/close_once_test.go +++ b/share/eds/close_once_test.go @@ -20,7 +20,7 @@ func TestWithClosedOnce(t *testing.T) { stub := &stubEdsAccessorCloser{} closedOnce := WithClosedOnce(stub) - _, err := closedOnce.Sample(ctx, 0, 0) + _, err := closedOnce.Sample(ctx, shwap.SampleCoords{}) require.NoError(t, err) _, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0) require.NoError(t, err) @@ -33,7 +33,7 @@ func TestWithClosedOnce(t *testing.T) { require.True(t, stub.closed) // Ensure that the underlying file is not accessible after closing - _, err = closedOnce.Sample(ctx, 0, 0) + _, err = closedOnce.Sample(ctx, shwap.SampleCoords{}) require.ErrorIs(t, err, errAccessorClosed) _, err = closedOnce.AxisHalf(ctx, rsmt2d.Row, 0) require.ErrorIs(t, err, errAccessorClosed) @@ -59,7 +59,7 @@ func (s *stubEdsAccessorCloser) AxisRoots(context.Context) (*share.AxisRoots, er return &share.AxisRoots{}, nil } -func (s *stubEdsAccessorCloser) Sample(context.Context, int, int) (shwap.Sample, error) { +func (s *stubEdsAccessorCloser) Sample(context.Context, shwap.SampleCoords) (shwap.Sample, error) { return shwap.Sample{}, nil } diff --git a/share/eds/proofs_cache.go b/share/eds/proofs_cache.go index e777b82962..f73ebcdf80 100644 --- a/share/eds/proofs_cache.go +++ b/share/eds/proofs_cache.go @@ -112,8 +112,8 @@ func (c *proofsCache) AxisRoots(ctx context.Context) (*share.AxisRoots, error) { return roots, nil } -func (c *proofsCache) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { - axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx +func (c *proofsCache) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) { + axisType, axisIdx, shrIdx := rsmt2d.Row, idx.Row, idx.Col ax, err := c.axisWithProofs(ctx, axisType, axisIdx) if err != nil { return shwap.Sample{}, err diff --git a/share/eds/rsmt2d.go b/share/eds/rsmt2d.go index e0e945fccb..6c244de700 100644 --- a/share/eds/rsmt2d.go +++ b/share/eds/rsmt2d.go @@ -46,18 +46,18 @@ func (eds *Rsmt2D) AxisRoots(context.Context) (*share.AxisRoots, error) { // Sample returns share and corresponding proof for row and column indices. func (eds *Rsmt2D) Sample( _ context.Context, - rowIdx, colIdx int, + idx shwap.SampleCoords, ) (shwap.Sample, error) { - return eds.SampleForProofAxis(rowIdx, colIdx, rsmt2d.Row) + return eds.SampleForProofAxis(idx, rsmt2d.Row) } // SampleForProofAxis samples a share from an Extended Data Square based on the provided // row and column indices and proof axis. It returns a sample with the share and proof. func (eds *Rsmt2D) SampleForProofAxis( - rowIdx, colIdx int, + idx shwap.SampleCoords, proofType rsmt2d.Axis, ) (shwap.Sample, error) { - axisIdx, shrIdx := relativeIndexes(rowIdx, colIdx, proofType) + axisIdx, shrIdx := relativeIndexes(idx.Row, idx.Col, proofType) shares, err := getAxis(eds.ExtendedDataSquare, proofType, axisIdx) if err != nil { return shwap.Sample{}, err diff --git a/share/eds/rsmt2d_test.go b/share/eds/rsmt2d_test.go index 96bde8c2ab..89fd316717 100644 --- a/share/eds/rsmt2d_test.go +++ b/share/eds/rsmt2d_test.go @@ -56,7 +56,9 @@ func TestRsmt2dSampleForProofAxis(t *testing.T) { for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := accessor.SampleForProofAxis(rowIdx, colIdx, proofType) + idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx} + + sample, err := accessor.SampleForProofAxis(idx, proofType) require.NoError(t, err) want := eds.GetCell(uint(rowIdx), uint(colIdx)) diff --git a/share/eds/testing.go b/share/eds/testing.go index 388544bc00..c8859dca5a 100644 --- a/share/eds/testing.go +++ b/share/eds/testing.go @@ -148,7 +148,8 @@ func testAccessorSample( // t.Parallel() this fails the test for some reason for rowIdx := 0; rowIdx < width; rowIdx++ { for colIdx := 0; colIdx < width; colIdx++ { - testSample(ctx, t, acc, roots, colIdx, rowIdx) + idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx} + testSample(ctx, t, acc, roots, idx) } } }) @@ -162,10 +163,11 @@ func testAccessorSample( for rowIdx := 0; rowIdx < width; rowIdx++ { for colIdx := 0; colIdx < width; colIdx++ { wg.Add(1) - go func(rowIdx, colIdx int) { + idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx} + go func(idx shwap.SampleCoords) { defer wg.Done() - testSample(ctx, t, acc, roots, rowIdx, colIdx) - }(rowIdx, colIdx) + testSample(ctx, t, acc, roots, idx) + }(idx) } } wg.Wait() @@ -182,8 +184,9 @@ func testAccessorSample( wg.Add(1) go func() { defer wg.Done() - rowIdx, colIdx := rand.IntN(width), rand.IntN(width) //nolint:gosec - testSample(ctx, t, acc, roots, rowIdx, colIdx) + rowIdx := rand.IntN(int(eds.Width())) //nolint:gosec + colIdx := rand.IntN(int(eds.Width())) //nolint:gosec + testSample(ctx, t, acc, roots, shwap.SampleCoords{Row: rowIdx, Col: colIdx}) }() } wg.Wait() @@ -195,12 +198,12 @@ func testSample( t *testing.T, acc Accessor, roots *share.AxisRoots, - rowIdx, colIdx int, + idx shwap.SampleCoords, ) { - shr, err := acc.Sample(ctx, rowIdx, colIdx) + shr, err := acc.Sample(ctx, idx) require.NoError(t, err) - err = shr.Verify(roots, rowIdx, colIdx) + err = shr.Verify(roots, idx.Row, idx.Col) require.NoError(t, err) } @@ -444,13 +447,15 @@ func BenchGetSampleFromAccessor( name := fmt.Sprintf("Size:%v/quadrant:%s", size, q) b.Run(name, func(b *testing.B) { rowIdx, colIdx := q.coordinates(acc.Size(ctx)) + idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx} + // warm up cache - _, err := acc.Sample(ctx, rowIdx, colIdx) + _, err := acc.Sample(ctx, idx) require.NoError(b, err, q.String()) b.ResetTimer() for i := 0; i < b.N; i++ { - _, err := acc.Sample(ctx, rowIdx, colIdx) + _, err := acc.Sample(ctx, idx) require.NoError(b, err) } }) diff --git a/share/eds/validation.go b/share/eds/validation.go index 845a5bac77..4f6cf0aa85 100644 --- a/share/eds/validation.go +++ b/share/eds/validation.go @@ -34,12 +34,12 @@ func (f validation) Size(ctx context.Context) int { return int(size) } -func (f validation) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { - _, err := shwap.NewSampleID(1, rowIdx, colIdx, f.Size(ctx)) +func (f validation) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) { + _, err := shwap.NewSampleID(1, idx, f.Size(ctx)) if err != nil { return shwap.Sample{}, fmt.Errorf("sample validation: %w", err) } - return f.Accessor.Sample(ctx, rowIdx, colIdx) + return f.Accessor.Sample(ctx, idx) } func (f validation) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (AxisHalf, error) { diff --git a/share/eds/validation_test.go b/share/eds/validation_test.go index 3e645cbfb3..9ec6b3fdb3 100644 --- a/share/eds/validation_test.go +++ b/share/eds/validation_test.go @@ -34,7 +34,9 @@ func TestValidation_Sample(t *testing.T) { accessor := &Rsmt2D{ExtendedDataSquare: randEDS} validation := WithValidation(AccessorAndStreamer(accessor, nil)) - _, err := validation.Sample(context.Background(), tt.rowIdx, tt.colIdx) + idx := shwap.SampleCoords{Row: tt.rowIdx, Col: tt.colIdx} + + _, err := validation.Sample(context.Background(), idx) if tt.expectFail { require.ErrorIs(t, err, shwap.ErrInvalidID) } else { diff --git a/share/shwap/getter.go b/share/shwap/getter.go index 21ac2ec49f..9e0a5d3131 100644 --- a/share/shwap/getter.go +++ b/share/shwap/getter.go @@ -22,6 +22,8 @@ var ( // ErrOutOfBounds is used to indicate that a passed row or column index is out of bounds of the // square size. ErrOutOfBounds = fmt.Errorf("index out of bounds: %w", ErrInvalidID) + // ErrNoSampleIndicies is used to indicate that no indicies where given to process. + ErrNoSampleIndicies = errors.New("no sample indicies to fetch") ) // Getter interface provides a set of accessors for shares by the Root. @@ -29,8 +31,10 @@ var ( // //go:generate mockgen -destination=getters/mock/getter.go -package=mock . Getter type Getter interface { - // GetShare gets a Share by coordinates in EDS. - GetShare(ctx context.Context, header *header.ExtendedHeader, row, col int) (libshare.Share, error) + // GetSamples gets samples by their indices. + // Returns Sample slice with requested number of samples in the requested order. + // May return partial response with some samples being empty if they weren't found. + GetSamples(ctx context.Context, header *header.ExtendedHeader, indices []SampleCoords) ([]Sample, error) // GetEDS gets the full EDS identified by the given extended header. GetEDS(context.Context, *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) diff --git a/share/shwap/getters/cascade.go b/share/shwap/getters/cascade.go index 962adffd3c..39ceb2fdb1 100644 --- a/share/shwap/getters/cascade.go +++ b/share/shwap/getters/cascade.go @@ -41,24 +41,17 @@ func NewCascadeGetter(getters []shwap.Getter) *CascadeGetter { } } -// GetShare gets a share from any of registered shwap.Getters in cascading order. -func (cg *CascadeGetter) GetShare( - ctx context.Context, header *header.ExtendedHeader, row, col int, -) (libshare.Share, error) { - ctx, span := tracer.Start(ctx, "cascade/get-share", trace.WithAttributes( - attribute.Int("row", row), - attribute.Int("col", col), +// GetSamples gets samples from any of registered shwap.Getters in cascading order. +func (cg *CascadeGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { + ctx, span := tracer.Start(ctx, "cascade/get-samples", trace.WithAttributes( + attribute.Int("amount", len(indices)), )) defer span.End() - upperBound := len(header.DAH.RowRoots) - if row >= upperBound || col >= upperBound { - err := shwap.ErrOutOfBounds - span.RecordError(err) - return libshare.Share{}, err - } - get := func(ctx context.Context, get shwap.Getter) (libshare.Share, error) { - return get.GetShare(ctx, header, row, col) + get := func(ctx context.Context, get shwap.Getter) ([]shwap.Sample, error) { + return get.GetSamples(ctx, hdr, indices) } return cascadeGetters(ctx, cg.getters, get) diff --git a/share/shwap/getters/cascade_test.go b/share/shwap/getters/cascade_test.go index a23568006f..8237d8ca5c 100644 --- a/share/shwap/getters/cascade_test.go +++ b/share/shwap/getters/cascade_test.go @@ -30,7 +30,7 @@ func TestCascadeGetter(t *testing.T) { getter := NewCascadeGetter(getters) t.Run("GetShare", func(t *testing.T) { for _, eh := range headers { - sh, err := getter.GetShare(ctx, eh, 0, 0) + sh, err := getter.GetSamples(ctx, eh, []shwap.SampleCoords{{}}) assert.NoError(t, err) assert.NotEmpty(t, sh) } diff --git a/share/shwap/getters/mock/getter.go b/share/shwap/getters/mock/getter.go index 856802d75d..7e4dacb24a 100644 --- a/share/shwap/getters/mock/getter.go +++ b/share/shwap/getters/mock/getter.go @@ -68,17 +68,17 @@ func (mr *MockGetterMockRecorder) GetNamespaceData(arg0, arg1, arg2 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNamespaceData", reflect.TypeOf((*MockGetter)(nil).GetNamespaceData), arg0, arg1, arg2) } -// GetShare mocks base method. -func (m *MockGetter) GetShare(arg0 context.Context, arg1 *header.ExtendedHeader, arg2, arg3 int) (share.Share, error) { +// GetSamples mocks base method. +func (m *MockGetter) GetSamples(arg0 context.Context, arg1 *header.ExtendedHeader, arg2 []shwap.SampleCoords) ([]shwap.Sample, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetShare", arg0, arg1, arg2, arg3) - ret0, _ := ret[0].(share.Share) + ret := m.ctrl.Call(m, "GetSamples", arg0, arg1, arg2) + ret0, _ := ret[0].([]shwap.Sample) ret1, _ := ret[1].(error) return ret0, ret1 } -// GetShare indicates an expected call of GetShare. -func (mr *MockGetterMockRecorder) GetShare(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// GetSamples indicates an expected call of GetSamples. +func (mr *MockGetterMockRecorder) GetSamples(arg0, arg1, arg2 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetShare", reflect.TypeOf((*MockGetter)(nil).GetShare), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetSamples", reflect.TypeOf((*MockGetter)(nil).GetSamples), arg0, arg1, arg2) } diff --git a/share/shwap/getters/testing.go b/share/shwap/getters/testing.go index a8fdd53ee6..a3ee53753d 100644 --- a/share/shwap/getters/testing.go +++ b/share/shwap/getters/testing.go @@ -14,43 +14,48 @@ import ( "github.com/celestiaorg/celestia-node/header" "github.com/celestiaorg/celestia-node/header/headertest" "github.com/celestiaorg/celestia-node/share" + "github.com/celestiaorg/celestia-node/share/eds" "github.com/celestiaorg/celestia-node/share/eds/edstest" "github.com/celestiaorg/celestia-node/share/shwap" ) // TestGetter provides a testing SingleEDSGetter and the root of the EDS it holds. func TestGetter(t *testing.T) (shwap.Getter, *header.ExtendedHeader) { - eds := edstest.RandEDS(t, 8) - roots, err := share.NewAxisRoots(eds) + square := edstest.RandEDS(t, 8) + roots, err := share.NewAxisRoots(square) eh := headertest.RandExtendedHeaderWithRoot(t, roots) require.NoError(t, err) return &SingleEDSGetter{ - EDS: eds, + EDS: eds.Rsmt2D{ExtendedDataSquare: square}, }, eh } // SingleEDSGetter contains a single EDS where data is retrieved from. // Its primary use is testing, and GetNamespaceData is not supported. type SingleEDSGetter struct { - EDS *rsmt2d.ExtendedDataSquare + EDS eds.Rsmt2D } -// GetShare gets a share from a kept EDS if exist and if the correct root is given. -func (seg *SingleEDSGetter) GetShare( - _ context.Context, - header *header.ExtendedHeader, - row, col int, -) (libshare.Share, error) { - err := seg.checkRoots(header.DAH) +// GetSamples get samples from a kept EDS if exist and if the correct root is given. +func (seg *SingleEDSGetter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { + err := seg.checkRoots(hdr.DAH) if err != nil { - return libshare.Share{}, err + return nil, err } - rawSh := seg.EDS.GetCell(uint(row), uint(col)) - sh, err := libshare.NewShare(rawSh) - if err != nil { - return libshare.Share{}, err + + smpls := make([]shwap.Sample, len(indices)) + for i, idx := range indices { + smpl, err := seg.EDS.Sample(ctx, idx) + if err != nil { + return nil, err + } + + smpls[i] = smpl } - return *sh, nil + + return smpls, nil } // GetEDS returns a kept EDS if the correct root is given. @@ -62,7 +67,7 @@ func (seg *SingleEDSGetter) GetEDS( if err != nil { return nil, err } - return seg.EDS, nil + return seg.EDS.ExtendedDataSquare, nil } // GetNamespaceData returns NamespacedShares from a kept EDS if the correct root is given. @@ -72,7 +77,7 @@ func (seg *SingleEDSGetter) GetNamespaceData(context.Context, *header.ExtendedHe } func (seg *SingleEDSGetter) checkRoots(roots *share.AxisRoots) error { - dah, err := da.NewDataAvailabilityHeader(seg.EDS) + dah, err := da.NewDataAvailabilityHeader(seg.EDS.ExtendedDataSquare) if err != nil { return err } diff --git a/share/shwap/p2p/bitswap/getter.go b/share/shwap/p2p/bitswap/getter.go index 308185a36c..db2938663c 100644 --- a/share/shwap/p2p/bitswap/getter.go +++ b/share/shwap/p2p/bitswap/getter.go @@ -11,6 +11,7 @@ import ( "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "github.com/celestiaorg/celestia-app/v3/pkg/wrapper" libshare "github.com/celestiaorg/go-square/v2/share" @@ -74,27 +75,24 @@ func (g *Getter) Stop() { g.cancel() } -// GetShares uses [SampleBlock] and [Fetch] to get and verify samples for given coordinates. -// TODO(@Wondertan): Rework API to get coordinates as a single param to make it ergonomic. -func (g *Getter) GetShares( +// GetSamples uses [SampleBlock] and [Fetch] to get and verify samples for given coordinates. +func (g *Getter) GetSamples( ctx context.Context, hdr *header.ExtendedHeader, - rowIdxs, colIdxs []int, -) ([]libshare.Share, error) { - if len(rowIdxs) != len(colIdxs) { - return nil, fmt.Errorf("row indecies and col indices must be same length") + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { + if len(indices) == 0 { + return nil, shwap.ErrNoSampleIndicies } - if len(rowIdxs) == 0 { - return nil, fmt.Errorf("empty coordinates") - } - - ctx, span := tracer.Start(ctx, "get-shares") + ctx, span := tracer.Start(ctx, "get-samples", trace.WithAttributes( + attribute.Int("amount", len(indices)), + )) defer span.End() - blks := make([]Block, len(rowIdxs)) - for i, rowIdx := range rowIdxs { - sid, err := NewEmptySampleBlock(hdr.Height(), rowIdx, colIdxs[i], len(hdr.DAH.RowRoots)) + blks := make([]Block, len(indices)) + for i, idx := range indices { + sid, err := NewEmptySampleBlock(hdr.Height(), idx, len(hdr.DAH.RowRoots)) if err != nil { span.RecordError(err) span.SetStatus(codes.Error, "NewEmptySampleBlock") @@ -111,37 +109,29 @@ func (g *Getter) GetShares( defer release() err := Fetch(ctx, g.exchange, hdr.DAH, blks, WithStore(g.bstore), WithFetcher(ses)) - if err != nil { - span.RecordError(err) - span.SetStatus(codes.Error, "Fetch") - return nil, err - } - shares := make([]libshare.Share, len(blks)) + var fetched int + smpls := make([]shwap.Sample, len(blks)) for i, blk := range blks { - shares[i] = blk.(*SampleBlock).Container.Share + c := blk.(*SampleBlock).Container + if !c.IsEmpty() { + fetched++ + smpls[i] = c + } } - span.SetStatus(codes.Ok, "") - return shares, nil -} - -// GetShare uses [GetShare] to fetch and verify single share by the given coordinates. -func (g *Getter) GetShare( - ctx context.Context, - hdr *header.ExtendedHeader, - row, col int, -) (libshare.Share, error) { - shrs, err := g.GetShares(ctx, hdr, []int{row}, []int{col}) if err != nil { - return libshare.Share{}, err - } - - if len(shrs) != 1 { - return libshare.Share{}, fmt.Errorf("expected 1 share row, got %d", len(shrs)) + span.RecordError(err) + span.SetStatus(codes.Error, "Fetch") + if fetched > 0 { + span.SetAttributes(attribute.Int("fetched", fetched)) + return smpls, err + } + return nil, err } - return shrs[0], nil + span.SetStatus(codes.Ok, "") + return smpls, nil } // GetEDS uses [RowBlock] and [Fetch] to get half of the first EDS quadrant(ODS) and diff --git a/share/shwap/p2p/bitswap/sample_block.go b/share/shwap/p2p/bitswap/sample_block.go index 0f25ccb454..8d22047330 100644 --- a/share/shwap/p2p/bitswap/sample_block.go +++ b/share/shwap/p2p/bitswap/sample_block.go @@ -47,8 +47,8 @@ type SampleBlock struct { } // NewEmptySampleBlock constructs a new empty SampleBlock. -func NewEmptySampleBlock(height uint64, rowIdx, colIdx, edsSize int) (*SampleBlock, error) { - id, err := shwap.NewSampleID(height, rowIdx, colIdx, edsSize) +func NewEmptySampleBlock(height uint64, idx shwap.SampleCoords, edsSize int) (*SampleBlock, error) { + id, err := shwap.NewSampleID(height, idx, edsSize) if err != nil { return nil, err } @@ -94,7 +94,9 @@ func (sb *SampleBlock) Marshal() ([]byte, error) { } func (sb *SampleBlock) Populate(ctx context.Context, eds eds.Accessor) error { - smpl, err := eds.Sample(ctx, sb.ID.RowIndex, sb.ID.ShareIndex) + idx := shwap.SampleCoords{Row: sb.ID.RowIndex, Col: sb.ID.ShareIndex} + + smpl, err := eds.Sample(ctx, idx) if err != nil { return fmt.Errorf("accessing Sample: %w", err) } diff --git a/share/shwap/p2p/bitswap/sample_block_test.go b/share/shwap/p2p/bitswap/sample_block_test.go index 2a28e7e4c9..3339f24314 100644 --- a/share/shwap/p2p/bitswap/sample_block_test.go +++ b/share/shwap/p2p/bitswap/sample_block_test.go @@ -9,6 +9,7 @@ import ( "github.com/celestiaorg/celestia-node/share" "github.com/celestiaorg/celestia-node/share/eds/edstest" + "github.com/celestiaorg/celestia-node/share/shwap" ) func TestSample_FetchRoundtrip(t *testing.T) { @@ -24,7 +25,9 @@ func TestSample_FetchRoundtrip(t *testing.T) { blks := make([]Block, 0, width*width) for x := 0; x < width; x++ { for y := 0; y < width; y++ { - blk, err := NewEmptySampleBlock(1, x, y, len(root.RowRoots)) + idx := shwap.SampleCoords{Row: x, Col: y} + + blk, err := NewEmptySampleBlock(1, idx, len(root.RowRoots)) require.NoError(t, err) blks = append(blks, blk) } diff --git a/share/shwap/p2p/shrex/shrex_getter/shrex.go b/share/shwap/p2p/shrex/shrex_getter/shrex.go index f51328136c..6c91a44736 100644 --- a/share/shwap/p2p/shrex/shrex_getter/shrex.go +++ b/share/shwap/p2p/shrex/shrex_getter/shrex.go @@ -146,8 +146,8 @@ func (sg *Getter) Stop(ctx context.Context) error { return sg.archivalPeerManager.Stop(ctx) } -func (sg *Getter) GetShare(context.Context, *header.ExtendedHeader, int, int) (libshare.Share, error) { - return libshare.Share{}, fmt.Errorf("getter/shrex: GetShare %w", shwap.ErrOperationNotSupported) +func (sg *Getter) GetSamples(context.Context, *header.ExtendedHeader, []shwap.SampleCoords) ([]shwap.Sample, error) { + return nil, fmt.Errorf("getter/shrex: GetShare %w", shwap.ErrOperationNotSupported) } func (sg *Getter) GetEDS(ctx context.Context, header *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) { diff --git a/share/shwap/sample.go b/share/shwap/sample.go index ab07d2f5de..9b8e16a93d 100644 --- a/share/shwap/sample.go +++ b/share/shwap/sample.go @@ -31,8 +31,8 @@ type Sample struct { // SampleFromShares creates a Sample from a list of shares, using the specified proof type and // the share index to be included in the sample. -func SampleFromShares(shares []libshare.Share, proofType rsmt2d.Axis, axisIdx, shrIdx int) (Sample, error) { - tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(axisIdx)) +func SampleFromShares(shares []libshare.Share, proofType rsmt2d.Axis, idx SampleCoords) (Sample, error) { + tree := wrapper.NewErasuredNamespacedMerkleTree(uint64(len(shares)/2), uint(idx.Row)) for _, shr := range shares { err := tree.Push(shr.ToBytes()) if err != nil { @@ -40,13 +40,13 @@ func SampleFromShares(shares []libshare.Share, proofType rsmt2d.Axis, axisIdx, s } } - proof, err := tree.ProveRange(shrIdx, shrIdx+1) + proof, err := tree.ProveRange(idx.Col, idx.Col+1) if err != nil { return Sample{}, err } return Sample{ - Share: shares[shrIdx], + Share: shares[idx.Col], Proof: &proof, ProofType: proofType, }, nil diff --git a/share/shwap/sample_id.go b/share/shwap/sample_id.go index b03bbacfda..a3b13279fd 100644 --- a/share/shwap/sample_id.go +++ b/share/shwap/sample_id.go @@ -10,6 +10,31 @@ import ( // bytes for the ShareIndex. const SampleIDSize = RowIDSize + 2 +type SampleCoords struct { + Row int `json:"row"` + Col int `json:"col"` +} + +func SampleCoordsAs1DIndex(idx SampleCoords, edsSize int) (int, error) { + if idx.Row < 0 || idx.Col < 0 { + return 0, fmt.Errorf("negative row or col index: %w", ErrInvalidID) + } + if idx.Row >= edsSize || idx.Col >= edsSize { + return 0, fmt.Errorf("SampleCoords %d || %d > %d: %w", idx.Row, idx.Col, edsSize, ErrOutOfBounds) + } + return idx.Row*edsSize + idx.Col, nil +} + +func SampleCoordsFrom1DIndex(idx, squareSize int) (SampleCoords, error) { + if idx > squareSize*squareSize { + return SampleCoords{}, fmt.Errorf("SampleCoords %d > %d: %w", idx, squareSize*squareSize, ErrOutOfBounds) + } + + rowIdx := idx / squareSize + colIdx := idx % squareSize + return SampleCoords{Row: rowIdx, Col: colIdx}, nil +} + // SampleID uniquely identifies a specific sample within a row of an Extended Data Square (EDS). type SampleID struct { RowID // Embeds RowID to incorporate block height and row index. @@ -18,15 +43,15 @@ type SampleID struct { // NewSampleID constructs a new SampleID using the provided block height, sample index, and EDS // size. It calculates the row and share index based on the sample index and EDS size. -func NewSampleID(height uint64, rowIdx, colIdx, edsSize int) (SampleID, error) { +func NewSampleID(height uint64, idx SampleCoords, edsSize int) (SampleID, error) { sid := SampleID{ RowID: RowID{ EdsID: EdsID{ Height: height, }, - RowIndex: rowIdx, + RowIndex: idx.Row, }, - ShareIndex: colIdx, + ShareIndex: idx.Col, } if err := sid.Verify(edsSize); err != nil { diff --git a/share/shwap/sample_id_test.go b/share/shwap/sample_id_test.go index 125d536854..3df2f6ce56 100644 --- a/share/shwap/sample_id_test.go +++ b/share/shwap/sample_id_test.go @@ -11,7 +11,7 @@ import ( func TestSampleID(t *testing.T) { edsSize := 4 - id, err := NewSampleID(1, 1, 1, edsSize) + id, err := NewSampleID(1, SampleCoords{Col: 1}, edsSize) require.NoError(t, err) data, err := id.MarshalBinary() @@ -29,7 +29,7 @@ func TestSampleID(t *testing.T) { func TestSampleIDReaderWriter(t *testing.T) { edsSize := 4 - id, err := NewSampleID(1, 1, 1, edsSize) + id, err := NewSampleID(1, SampleCoords{Col: 1}, edsSize) require.NoError(t, err) buf := bytes.NewBuffer(nil) @@ -44,3 +44,15 @@ func TestSampleIDReaderWriter(t *testing.T) { require.EqualValues(t, id, sidOut) } + +func TestSampleCoords(t *testing.T) { + edsSize := 16 + + rawIdx := 13 * 16 + idxIn, err := SampleCoordsFrom1DIndex(rawIdx, edsSize) + require.NoError(t, err) + + idxOut, err := SampleCoordsAs1DIndex(idxIn, edsSize) + require.NoError(t, err) + assert.Equal(t, rawIdx, idxOut) +} diff --git a/share/shwap/sample_test.go b/share/shwap/sample_test.go index 030eeb4677..0c03eb47cc 100644 --- a/share/shwap/sample_test.go +++ b/share/shwap/sample_test.go @@ -25,10 +25,12 @@ func TestSampleValidate(t *testing.T) { for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) + idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx} + + sample, err := inMem.SampleForProofAxis(idx, proofType) require.NoError(t, err) - require.NoError(t, sample.Verify(root, rowIdx, colIdx)) + require.NoError(t, sample.Verify(root, rowIdx, colIdx), "row: %d col: %d", rowIdx, colIdx) } } } @@ -42,7 +44,7 @@ func TestSampleNegativeVerifyInclusion(t *testing.T) { require.NoError(t, err) inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} - sample, err := inMem.Sample(context.Background(), 0, 0) + sample, err := inMem.Sample(context.Background(), shwap.SampleCoords{}) require.NoError(t, err) err = sample.Verify(root, 0, 0) require.NoError(t, err) @@ -61,14 +63,14 @@ func TestSampleNegativeVerifyInclusion(t *testing.T) { require.ErrorIs(t, err, shwap.ErrFailedVerification) // incorrect proofType - sample, err = inMem.Sample(context.Background(), 0, 0) + sample, err = inMem.Sample(context.Background(), shwap.SampleCoords{}) require.NoError(t, err) sample.ProofType = rsmt2d.Col err = sample.Verify(root, 0, 0) require.ErrorIs(t, err, shwap.ErrFailedVerification) // Corrupt the last root hash byte - sample, err = inMem.Sample(context.Background(), 0, 0) + sample, err = inMem.Sample(context.Background(), shwap.SampleCoords{}) require.NoError(t, err) root.RowRoots[0][len(root.RowRoots[0])-1] ^= 0xFF err = sample.Verify(root, 0, 0) @@ -83,7 +85,9 @@ func TestSampleProtoEncoding(t *testing.T) { for _, proofType := range []rsmt2d.Axis{rsmt2d.Row, rsmt2d.Col} { for rowIdx := 0; rowIdx < odsSize*2; rowIdx++ { for colIdx := 0; colIdx < odsSize*2; colIdx++ { - sample, err := inMem.SampleForProofAxis(rowIdx, colIdx, proofType) + idx := shwap.SampleCoords{Row: rowIdx, Col: colIdx} + + sample, err := inMem.SampleForProofAxis(idx, proofType) require.NoError(t, err) pb := sample.ToProto() @@ -103,7 +107,8 @@ func BenchmarkSampleValidate(b *testing.B) { root, err := share.NewAxisRoots(randEDS) require.NoError(b, err) inMem := eds.Rsmt2D{ExtendedDataSquare: randEDS} - sample, err := inMem.SampleForProofAxis(0, 0, rsmt2d.Row) + + sample, err := inMem.SampleForProofAxis(shwap.SampleCoords{}, rsmt2d.Row) require.NoError(b, err) b.ResetTimer() diff --git a/store/cache/accessor_cache_test.go b/store/cache/accessor_cache_test.go index 9c7104fbe7..8b537049e1 100644 --- a/store/cache/accessor_cache_test.go +++ b/store/cache/accessor_cache_test.go @@ -315,7 +315,7 @@ func (m *mockAccessor) AxisRoots(context.Context) (*share.AxisRoots, error) { panic("implement me") } -func (m *mockAccessor) Sample(context.Context, int, int) (shwap.Sample, error) { +func (m *mockAccessor) Sample(context.Context, shwap.SampleCoords) (shwap.Sample, error) { panic("implement me") } diff --git a/store/cache/noop.go b/store/cache/noop.go index d777fdb2e4..27b33e2dd6 100644 --- a/store/cache/noop.go +++ b/store/cache/noop.go @@ -59,7 +59,7 @@ func (n NoopFile) AxisRoots(context.Context) (*share.AxisRoots, error) { return &share.AxisRoots{}, nil } -func (n NoopFile) Sample(context.Context, int, int) (shwap.Sample, error) { +func (n NoopFile) Sample(context.Context, shwap.SampleCoords) (shwap.Sample, error) { return shwap.Sample{}, nil } diff --git a/store/file/codec.go b/store/file/codec.go index a27280be11..f8e7b91ec9 100644 --- a/store/file/codec.go +++ b/store/file/codec.go @@ -13,7 +13,7 @@ func init() { } type Codec interface { - Encoder(len int) (reedsolomon.Encoder, error) + Encoder(ln int) (reedsolomon.Encoder, error) } type codecCache struct { @@ -24,15 +24,15 @@ func NewCodec() Codec { return &codecCache{} } -func (l *codecCache) Encoder(len int) (reedsolomon.Encoder, error) { - enc, ok := l.cache.Load(len) +func (l *codecCache) Encoder(ln int) (reedsolomon.Encoder, error) { + enc, ok := l.cache.Load(ln) if !ok { var err error - enc, err = reedsolomon.New(len/2, len/2, reedsolomon.WithLeopardGF(true)) + enc, err = reedsolomon.New(ln/2, ln/2, reedsolomon.WithLeopardGF(true)) if err != nil { return nil, err } - l.cache.Store(len, enc) + l.cache.Store(ln, enc) } return enc.(reedsolomon.Encoder), nil } diff --git a/store/file/ods.go b/store/file/ods.go index 3ec5082b83..e9b14e8a80 100644 --- a/store/file/ods.go +++ b/store/file/ods.go @@ -228,7 +228,7 @@ func (o *ODS) Close() error { // Sample returns share and corresponding proof for row and column indices. Implementation can // choose which axis to use for proof. Chosen axis for proof should be indicated in the returned // Sample. -func (o *ODS) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (o *ODS) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) { // Sample proof axis is selected to optimize read performance. // - For the first and second quadrants, we read the row axis because it is more efficient to read // single row than reading full ODS to calculate single column @@ -236,6 +236,8 @@ func (o *ODS) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, err // column than reading full ODS to calculate single row // - For the fourth quadrant, it does not matter which axis we read because we need to read full ODS // to calculate the sample + rowIdx, colIdx := idx.Row, idx.Col + axisType, axisIdx, shrIdx := rsmt2d.Row, rowIdx, colIdx if colIdx < o.size()/2 && rowIdx >= o.size()/2 { axisType, axisIdx, shrIdx = rsmt2d.Col, colIdx, rowIdx @@ -246,7 +248,9 @@ func (o *ODS) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, err return shwap.Sample{}, fmt.Errorf("reading axis: %w", err) } - return shwap.SampleFromShares(axis, axisType, axisIdx, shrIdx) + idxNew := shwap.SampleCoords{Row: axisIdx, Col: shrIdx} + + return shwap.SampleFromShares(axis, axisType, idxNew) } // AxisHalf returns half of shares axis of the given type and index. Side is determined by diff --git a/store/file/ods_q4.go b/store/file/ods_q4.go index 06b255cae9..f0ca686094 100644 --- a/store/file/ods_q4.go +++ b/store/file/ods_q4.go @@ -122,9 +122,9 @@ func (odsq4 *ODSQ4) AxisRoots(ctx context.Context) (*share.AxisRoots, error) { return odsq4.ods.AxisRoots(ctx) } -func (odsq4 *ODSQ4) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sample, error) { +func (odsq4 *ODSQ4) Sample(ctx context.Context, idx shwap.SampleCoords) (shwap.Sample, error) { // use native AxisHalf implementation, to read axis from q4 quadrant when possible - half, err := odsq4.AxisHalf(ctx, rsmt2d.Row, rowIdx) + half, err := odsq4.AxisHalf(ctx, rsmt2d.Row, idx.Row) if err != nil { return shwap.Sample{}, fmt.Errorf("reading axis: %w", err) } @@ -132,7 +132,8 @@ func (odsq4 *ODSQ4) Sample(ctx context.Context, rowIdx, colIdx int) (shwap.Sampl if err != nil { return shwap.Sample{}, fmt.Errorf("extending shares: %w", err) } - return shwap.SampleFromShares(shares, rsmt2d.Row, rowIdx, colIdx) + + return shwap.SampleFromShares(shares, rsmt2d.Row, idx) } func (odsq4 *ODSQ4) AxisHalf(ctx context.Context, axisType rsmt2d.Axis, axisIdx int) (eds.AxisHalf, error) { diff --git a/store/getter.go b/store/getter.go index 54ce70e4ae..1315561730 100644 --- a/store/getter.go +++ b/store/getter.go @@ -24,26 +24,29 @@ func NewGetter(store *Store) *Getter { return &Getter{store: store} } -func (g *Getter) GetShare(ctx context.Context, h *header.ExtendedHeader, row, col int) (libshare.Share, error) { - acc, err := g.store.GetByHeight(ctx, h.Height()) +func (g *Getter) GetSamples(ctx context.Context, hdr *header.ExtendedHeader, + indices []shwap.SampleCoords, +) ([]shwap.Sample, error) { + acc, err := g.store.GetByHeight(ctx, hdr.Height()) if err != nil { if errors.Is(err, ErrNotFound) { - return libshare.Share{}, shwap.ErrNotFound + return nil, shwap.ErrNotFound } - return libshare.Share{}, fmt.Errorf("get accessor from store:%w", err) + return nil, fmt.Errorf("get accessor from store:%w", err) } - logger := log.With( - "height", h.Height(), - "row", row, - "col", col, - ) - defer utils.CloseAndLog(logger, "getter/sample", acc) + defer utils.CloseAndLog(log.With("height", hdr.Height()), "getter/sample", acc) - sample, err := acc.Sample(ctx, row, col) - if err != nil { - return libshare.Share{}, fmt.Errorf("get sample from accessor:%w", err) + smpls := make([]shwap.Sample, len(indices)) + for i, idx := range indices { + smpl, err := acc.Sample(ctx, idx) + if err != nil { + return nil, fmt.Errorf("get sample from accessor:%w", err) + } + + smpls[i] = smpl } - return sample.Share, nil + + return smpls, nil } func (g *Getter) GetEDS(ctx context.Context, h *header.ExtendedHeader) (*rsmt2d.ExtendedDataSquare, error) { diff --git a/store/getter_test.go b/store/getter_test.go index 47c0ce82a4..b0b027fc19 100644 --- a/store/getter_test.go +++ b/store/getter_test.go @@ -36,14 +36,17 @@ func TestStoreGetter(t *testing.T) { squareSize := int(eds.Width()) for i := 0; i < squareSize; i++ { for j := 0; j < squareSize; j++ { - share, err := sg.GetShare(ctx, eh, i, j) + idx := shwap.SampleCoords{Row: i, Col: j} + + smpls, err := sg.GetSamples(ctx, eh, []shwap.SampleCoords{idx}) require.NoError(t, err) - require.Equal(t, eds.GetCell(uint(i), uint(j)), share.ToBytes()) + require.Equal(t, eds.GetCell(uint(i), uint(j)), smpls[0].Share.ToBytes()) } } // doesn't panic on indexes too high - _, err = sg.GetShare(ctx, eh, squareSize, squareSize) + bigIdx := squareSize * squareSize + _, err = sg.GetSamples(ctx, eh, []shwap.SampleCoords{{Row: bigIdx, Col: bigIdx}}) require.ErrorIs(t, err, shwap.ErrOutOfBounds) })