diff --git a/common/aws/cli.go b/common/aws/cli.go
index d1b3bf274a..e646712175 100644
--- a/common/aws/cli.go
+++ b/common/aws/cli.go
@@ -37,12 +37,6 @@ type ClientConfig struct {
     // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files.
     // A non-zero value for this parameter adds a constant number of workers. Default is 0.
     FragmentParallelismConstant int
-    // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read.
-    // Default is 30 seconds.
-    FragmentReadTimeout time.Duration
-    // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write.
-    // Default is 30 seconds.
-    FragmentWriteTimeout time.Duration
 }
 
 func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
@@ -120,8 +114,6 @@ func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig {
         EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
         FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)),
         FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)),
-        FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)),
-        FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)),
     }
 }
 
@@ -131,7 +123,5 @@ func DefaultClientConfig() *ClientConfig {
         Region: "us-east-2",
         FragmentParallelismFactor: 8,
         FragmentParallelismConstant: 0,
-        FragmentReadTimeout: 30 * time.Second,
-        FragmentWriteTimeout: 30 * time.Second,
     }
 }
diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go
index 8bb35b37f3..c3ae159d41 100644
--- a/common/aws/s3/client.go
+++ b/common/aws/s3/client.go
@@ -241,9 +241,6 @@ func (s *client) FragmentedUploadObject(
     }
     resultChannel := make(chan error, len(fragments))
 
-    ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
-    defer cancel()
-
     for _, fragment := range fragments {
         fragmentCapture := fragment
         s.concurrencyLimiter <- struct{}{}
@@ -301,9 +298,6 @@ func (s *client) FragmentedDownloadObject(
     }
     resultChannel := make(chan *readResult, len(fragmentKeys))
 
-    ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
-    defer cancel()
-
     for i, fragmentKey := range fragmentKeys {
         boundFragmentKey := fragmentKey
         boundI := i
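With FragmentReadTimeout and FragmentWriteTimeout removed from the two files above, the S3 client no longer imposes its own deadline on fragmented reads and writes; any bound now has to arrive through the context the caller passes in. A minimal caller-side sketch of that pattern follows (illustrative only: the operation is wrapped in a closure because the full client method signatures are not shown in this diff, and the 30-second value simply mirrors the removed default).

package main

import (
    "context"
    "time"
)

// boundedFragmentedOp runs a fragmented S3 read or write under a caller-owned deadline.
// The operation is passed as a closure so the exact s3 client signature stays out of scope here.
func boundedFragmentedOp(ctx context.Context, op func(ctx context.Context) error) error {
    // Assumption: 30s is a placeholder; tune per call site.
    ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
    defer cancel()
    return op(ctx)
}

In other words, the timeout responsibility moves from ClientConfig to each call site, which is why the relay configuration later in this diff grows its own set of timeout values.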
diff --git a/relay/auth/authenticator.go b/relay/auth/authenticator.go
index a85c1f4865..2e89c83d18 100644
--- a/relay/auth/authenticator.go
+++ b/relay/auth/authenticator.go
@@ -17,6 +17,7 @@ type RequestAuthenticator interface {
     // The origin is the address of the peer that sent the request. This may be used to cache auth results
     // in order to save server resources.
     AuthenticateGetChunksRequest(
+        ctx context.Context,
         origin string,
         request *pb.GetChunksRequest,
         now time.Time) error
@@ -53,6 +54,7 @@ type requestAuthenticator struct {
 
 // NewRequestAuthenticator creates a new RequestAuthenticator.
 func NewRequestAuthenticator(
+    ctx context.Context,
     ics core.IndexedChainState,
     keyCacheSize int,
     authenticationTimeoutDuration time.Duration) (RequestAuthenticator, error) {
@@ -70,7 +72,7 @@ func NewRequestAuthenticator(
         keyCache: keyCache,
     }
 
-    err = authenticator.preloadCache()
+    err = authenticator.preloadCache(ctx)
     if err != nil {
         return nil, fmt.Errorf("failed to preload cache: %w", err)
     }
@@ -78,12 +80,12 @@ func NewRequestAuthenticator(
     return authenticator, nil
 }
 
-func (a *requestAuthenticator) preloadCache() error {
+func (a *requestAuthenticator) preloadCache(ctx context.Context) error {
     blockNumber, err := a.ics.GetCurrentBlockNumber()
     if err != nil {
         return fmt.Errorf("failed to get current block number: %w", err)
     }
-    operators, err := a.ics.GetIndexedOperators(context.Background(), blockNumber)
+    operators, err := a.ics.GetIndexedOperators(ctx, blockNumber)
     if err != nil {
         return fmt.Errorf("failed to get operators: %w", err)
     }
@@ -96,6 +98,7 @@ func (a *requestAuthenticator) preloadCache() error {
 }
 
 func (a *requestAuthenticator) AuthenticateGetChunksRequest(
+    ctx context.Context,
     origin string,
     request *pb.GetChunksRequest,
     now time.Time) error {
@@ -105,7 +108,7 @@ func (a *requestAuthenticator) AuthenticateGetChunksRequest(
         return nil
     }
 
-    key, err := a.getOperatorKey(core.OperatorID(request.OperatorId))
+    key, err := a.getOperatorKey(ctx, core.OperatorID(request.OperatorId))
     if err != nil {
         return fmt.Errorf("failed to get operator key: %w", err)
     }
@@ -131,7 +134,7 @@ func (a *requestAuthenticator) AuthenticateGetChunksRequest(
 }
 
 // getOperatorKey returns the public key of the operator with the given ID, caching the result.
-func (a *requestAuthenticator) getOperatorKey(operatorID core.OperatorID) (*core.G2Point, error) {
+func (a *requestAuthenticator) getOperatorKey(ctx context.Context, operatorID core.OperatorID) (*core.G2Point, error) {
     key, ok := a.keyCache.Get(operatorID)
     if ok {
         return key, nil
@@ -141,7 +144,7 @@ func (a *requestAuthenticator) getOperatorKey(operatorID core.OperatorID) (*core
     if err != nil {
         return nil, fmt.Errorf("failed to get current block number: %w", err)
     }
-    operators, err := a.ics.GetIndexedOperators(context.Background(), blockNumber)
+    operators, err := a.ics.GetIndexedOperators(ctx, blockNumber)
     if err != nil {
         return nil, fmt.Errorf("failed to get operators: %w", err)
     }
diff --git a/relay/auth/authenticator_test.go b/relay/auth/authenticator_test.go
index debcbccc61..8569376984 100644
--- a/relay/auth/authenticator_test.go
+++ b/relay/auth/authenticator_test.go
@@ -15,6 +15,8 @@ import (
 func TestMockSigning(t *testing.T) {
     tu.InitializeRandom()
 
+    ctx := context.Background()
+
     operatorID := mock.MakeOperatorId(0)
     stakes := map[core.QuorumID]map[core.OperatorID]int{
         core.QuorumID(0): {
@@ -24,7 +26,7 @@ func TestMockSigning(t *testing.T) {
     ics, err := mock.NewChainDataMock(stakes)
     require.NoError(t, err)
 
-    operators, err := ics.GetIndexedOperators(context.Background(), 0)
+    operators, err := ics.GetIndexedOperators(ctx, 0)
     require.NoError(t, err)
 
     operator, ok := operators[operatorID]
@@ -46,6 +48,8 @@ func TestMockSigning(t *testing.T) {
 func TestValidRequest(t *testing.T) {
     tu.InitializeRandom()
 
+    ctx := context.Background()
+
     operatorID := mock.MakeOperatorId(0)
     stakes := map[core.QuorumID]map[core.OperatorID]int{
         core.QuorumID(0): {
@@ -58,7 +62,7 @@ func TestValidRequest(t *testing.T) {
 
     timeout := 10 * time.Second
 
-    authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
+    authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
     require.NoError(t, err)
 
     request := randomGetChunksRequest()
@@ -69,6 +73,7 @@ func TestValidRequest(t *testing.T) {
     now := time.Now()
 
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         request,
         now)
@@ -83,12 +88,14 @@ func TestValidRequest(t *testing.T) {
     start := now
     for now.Before(start.Add(timeout)) {
         err = authenticator.AuthenticateGetChunksRequest(
+            ctx,
             "foobar",
             invalidRequest,
             now)
         require.NoError(t, err)
 
         err = authenticator.AuthenticateGetChunksRequest(
+            ctx,
             "baz",
             invalidRequest,
             now)
@@ -99,6 +106,7 @@ func TestValidRequest(t *testing.T) {
 
     // After the timeout elapses, new requests should trigger authentication.
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         invalidRequest,
         now)
@@ -108,6 +116,8 @@ func TestValidRequest(t *testing.T) {
 func TestAuthenticationSavingDisabled(t *testing.T) {
     tu.InitializeRandom()
 
+    ctx := context.Background()
+
     operatorID := mock.MakeOperatorId(0)
     stakes := map[core.QuorumID]map[core.OperatorID]int{
         core.QuorumID(0): {
@@ -121,7 +131,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
     // This disables saving of authentication results.
     timeout := time.Duration(0)
 
-    authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
+    authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
     require.NoError(t, err)
 
     request := randomGetChunksRequest()
@@ -132,6 +142,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
     now := time.Now()
 
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         request,
         now)
@@ -144,6 +155,7 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
     invalidRequest.OperatorSignature = signature // the previous signature is invalid here
 
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         invalidRequest,
         now)
@@ -153,6 +165,8 @@ func TestAuthenticationSavingDisabled(t *testing.T) {
 func TestNonExistingClient(t *testing.T) {
     tu.InitializeRandom()
 
+    ctx := context.Background()
+
     operatorID := mock.MakeOperatorId(0)
     stakes := map[core.QuorumID]map[core.OperatorID]int{
         core.QuorumID(0): {
@@ -165,7 +179,7 @@ func TestNonExistingClient(t *testing.T) {
 
     timeout := 10 * time.Second
 
-    authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
+    authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
     require.NoError(t, err)
 
     invalidOperatorID := tu.RandomBytes(32)
@@ -174,6 +188,7 @@ func TestNonExistingClient(t *testing.T) {
     request.OperatorId = invalidOperatorID
 
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         request,
         time.Now())
@@ -183,6 +198,8 @@ func TestNonExistingClient(t *testing.T) {
 func TestBadSignature(t *testing.T) {
     tu.InitializeRandom()
 
+    ctx := context.Background()
+
     operatorID := mock.MakeOperatorId(0)
     stakes := map[core.QuorumID]map[core.OperatorID]int{
         core.QuorumID(0): {
@@ -195,7 +212,7 @@ func TestBadSignature(t *testing.T) {
 
     timeout := 10 * time.Second
 
-    authenticator, err := NewRequestAuthenticator(ics, 1024, timeout)
+    authenticator, err := NewRequestAuthenticator(ctx, ics, 1024, timeout)
     require.NoError(t, err)
 
     request := randomGetChunksRequest()
@@ -205,6 +222,7 @@ func TestBadSignature(t *testing.T) {
     now := time.Now()
 
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         request,
         now)
@@ -217,6 +235,7 @@ func TestBadSignature(t *testing.T) {
     request.OperatorSignature[0] = request.OperatorSignature[0] ^ 1
 
     err = authenticator.AuthenticateGetChunksRequest(
+        ctx,
         "foobar",
         request,
         now)
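The authenticator changes above thread a caller-supplied context down to GetIndexedOperators, so the chain lookups made during cache preload and on key-cache misses can now be cancelled or given a deadline. A sketch of how a caller might use that (sketch only, imports elided; the constructor and method signatures are the ones introduced in this diff, while the 5-second deadline, the 1024 key-cache size, the 10-second auth-result retention, and the peer address are illustrative assumptions):

func authenticateWithDeadline(ctx context.Context, ics core.IndexedChainState, request *pb.GetChunksRequest) error {
    // Bound every on-chain lookup the authenticator performs, including preloadCache.
    ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
    defer cancel()

    authenticator, err := auth.NewRequestAuthenticator(ctx, ics, 1024, 10*time.Second)
    if err != nil {
        return err
    }
    return authenticator.AuthenticateGetChunksRequest(ctx, "203.0.113.7:50051", request, time.Now())
}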
diff --git a/relay/blob_provider.go b/relay/blob_provider.go
index 44157f6069..9b9863bfda 100644
--- a/relay/blob_provider.go
+++ b/relay/blob_provider.go
@@ -7,6 +7,7 @@ import (
     "github.com/Layr-Labs/eigenda/disperser/common/v2/blobstore"
     "github.com/Layr-Labs/eigenda/relay/cache"
     "github.com/Layr-Labs/eigensdk-go/logging"
+    "time"
 )
 
 // blobProvider encapsulates logic for fetching blobs. Utilized by the relay Server.
@@ -20,6 +21,9 @@ type blobProvider struct {
 
     // blobCache is an LRU cache of blobs.
     blobCache cache.CachedAccessor[v2.BlobKey, []byte]
+
+    // fetchTimeout is the maximum time to wait for a blob fetch operation to complete.
+    fetchTimeout time.Duration
 }
 
 // newBlobProvider creates a new blobProvider.
@@ -28,12 +32,14 @@ func newBlobProvider(
     logger logging.Logger,
     blobStore *blobstore.BlobStore,
     blobCacheSize int,
-    maxIOConcurrency int) (*blobProvider, error) {
+    maxIOConcurrency int,
+    fetchTimeout time.Duration) (*blobProvider, error) {
 
     server := &blobProvider{
-        ctx:       ctx,
-        logger:    logger,
-        blobStore: blobStore,
+        ctx:          ctx,
+        logger:       logger,
+        blobStore:    blobStore,
+        fetchTimeout: fetchTimeout,
     }
 
     c, err := cache.NewCachedAccessor[v2.BlobKey, []byte](blobCacheSize, maxIOConcurrency, server.fetchBlob)
@@ -46,9 +52,8 @@ func newBlobProvider(
 }
 
 // GetBlob retrieves a blob from the blob store.
-func (s *blobProvider) GetBlob(blobKey v2.BlobKey) ([]byte, error) {
-
-    data, err := s.blobCache.Get(blobKey)
+func (s *blobProvider) GetBlob(ctx context.Context, blobKey v2.BlobKey) ([]byte, error) {
+    data, err := s.blobCache.Get(ctx, blobKey)
 
     if err != nil {
         // It should not be possible for external users to force an error here since we won't
@@ -62,7 +67,10 @@ func (s *blobProvider) GetBlob(blobKey v2.BlobKey) ([]byte, error) {
 
 // fetchBlob retrieves a single blob from the blob store.
 func (s *blobProvider) fetchBlob(blobKey v2.BlobKey) ([]byte, error) {
-    data, err := s.blobStore.GetBlob(s.ctx, blobKey)
+    ctx, cancel := context.WithTimeout(s.ctx, s.fetchTimeout)
+    defer cancel()
+
+    data, err := s.blobStore.GetBlob(ctx, blobKey)
     if err != nil {
         s.logger.Errorf("Failed to fetch blob: %v", err)
         return nil, err
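The blobProvider hunks above follow a pattern that recurs in the chunk and metadata providers later in this diff: the provider keeps a long-lived base context plus a per-operation timeout, and each store call runs under a short-lived child context. The fetch deliberately derives from the provider's own context rather than from any single request's context, because one fetch may be shared by several requests through the cached accessor. A reduced sketch of the shape (assumed names, not code from this diff):

import (
    "context"
    "time"
)

type timedFetcher struct {
    baseCtx      context.Context // long-lived, owned by the server
    fetchTimeout time.Duration   // e.g. the relay's InternalGetBlobTimeout
}

// fetch bounds a single store read without inheriting any one caller's deadline.
func (f *timedFetcher) fetch(get func(ctx context.Context) ([]byte, error)) ([]byte, error) {
    ctx, cancel := context.WithTimeout(f.baseCtx, f.fetchTimeout)
    defer cancel()
    return get(ctx)
}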
diff --git a/relay/blob_provider_test.go b/relay/blob_provider_test.go
index 6e996977bb..9309461c65 100644
--- a/relay/blob_provider_test.go
+++ b/relay/blob_provider_test.go
@@ -7,6 +7,7 @@ import (
     v2 "github.com/Layr-Labs/eigenda/core/v2"
     "github.com/stretchr/testify/require"
     "testing"
+    "time"
 )
 
 func TestReadWrite(t *testing.T) {
@@ -34,12 +35,18 @@ func TestReadWrite(t *testing.T) {
         require.NoError(t, err)
     }
 
-    server, err := newBlobProvider(context.Background(), logger, blobStore, 10, 32)
+    server, err := newBlobProvider(
+        context.Background(),
+        logger,
+        blobStore,
+        10,
+        32,
+        10*time.Second)
     require.NoError(t, err)
 
     // Read the blobs back.
     for key, data := range expectedData {
-        blob, err := server.GetBlob(key)
+        blob, err := server.GetBlob(context.Background(), key)
         require.NoError(t, err)
 
         require.Equal(t, data, blob)
@@ -47,7 +54,7 @@ func TestReadWrite(t *testing.T) {
 
     // Read the blobs back again to test caching.
     for key, data := range expectedData {
-        blob, err := server.GetBlob(key)
+        blob, err := server.GetBlob(context.Background(), key)
         require.NoError(t, err)
 
         require.Equal(t, data, blob)
@@ -65,11 +72,17 @@ func TestNonExistentBlob(t *testing.T) {
 
     blobStore := buildBlobStore(t, logger)
 
-    server, err := newBlobProvider(context.Background(), logger, blobStore, 10, 32)
+    server, err := newBlobProvider(
+        context.Background(),
+        logger,
+        blobStore,
+        10,
+        32,
+        10*time.Second)
     require.NoError(t, err)
 
     for i := 0; i < 10; i++ {
-        blob, err := server.GetBlob(v2.BlobKey(tu.RandomBytes(32)))
+        blob, err := server.GetBlob(context.Background(), v2.BlobKey(tu.RandomBytes(32)))
         require.Error(t, err)
         require.Nil(t, blob)
     }
 }
diff --git a/relay/cache/cached_accessor.go b/relay/cache/cached_accessor.go
index e39a3a3910..d131229082 100644
--- a/relay/cache/cached_accessor.go
+++ b/relay/cache/cached_accessor.go
@@ -1,7 +1,9 @@
 package cache
 
 import (
+    "context"
     lru "github.com/hashicorp/golang-lru/v2"
+    "golang.org/x/sync/semaphore"
     "sync"
 )
 
@@ -9,7 +11,9 @@ import (
 // are expensive, and prevents multiple concurrent cache misses for the same key.
 type CachedAccessor[K comparable, V any] interface {
     // Get returns the value for the given key. If the value is not in the cache, it will be fetched using the Accessor.
-    Get(key K) (V, error)
+    // If the context is cancelled, the function may abort early. If multiple goroutines request the same key,
+    // cancellation of one request will not affect the others.
+    Get(ctx context.Context, key K) (V, error)
 }
 
 // Accessor is function capable of fetching a value from a resource. Used by CachedAccessor when there is a cache miss.
@@ -17,8 +21,8 @@ type Accessor[K comparable, V any] func(key K) (V, error)
 
 // accessResult is a struct that holds the result of an Accessor call.
 type accessResult[V any] struct {
-    // wg.Wait() will block until the value is fetched.
-    wg sync.WaitGroup
+    // sem is a semaphore used to signal that the value has been fetched.
+    sem *semaphore.Weighted
     // value is the value fetched by the Accessor, or nil if there was an error.
     value V
     // err is the error returned by the Accessor, or nil if the fetch was successful.
@@ -34,7 +38,6 @@ var _ CachedAccessor[string, string] = &cachedAccessor[string, string]{}
 
 // cachedAccessor is an implementation of CachedAccessor.
 type cachedAccessor[K comparable, V any] struct {
-
     // lookupsInProgress has an entry for each key that is currently being looked up via the accessor. The value
     // is written into the channel when it is eventually fetched. If a key is requested more than once while a
     // lookup in progress, the second (and following) requests will wait for the result of the first lookup
@@ -86,14 +89,13 @@ func NewCachedAccessor[K comparable, V any](
 
 func newAccessResult[V any]() *accessResult[V] {
     result := &accessResult[V]{
-        wg: sync.WaitGroup{},
+        sem: semaphore.NewWeighted(1),
     }
-    result.wg.Add(1)
+    _ = result.sem.Acquire(context.Background(), 1)
     return result
 }
 
-func (c *cachedAccessor[K, V]) Get(key K) (V, error) {
-
+func (c *cachedAccessor[K, V]) Get(ctx context.Context, key K) (V, error) {
     c.cacheLock.Lock()
 
     // first, attempt to get the value from the cache
@@ -114,11 +116,35 @@ func (c *cachedAccessor[K, V]) Get(key K) (V, error) {
 
     if alreadyLoading {
         // The result is being fetched on another goroutine. Wait for it to finish.
-        result.wg.Wait()
-        return result.value, result.err
+        return c.waitForResult(ctx, result)
     } else {
         // We are the first goroutine to request this key.
+        return c.fetchResult(ctx, key, result)
+    }
+}
+// waitForResult waits for the result of a lookup that was initiated by another requester and returns it
+// when it becomes available. This method will return quickly if the provided context is cancelled.
+// Doing so does not disrupt the other requesters that are also waiting for this result.
+func (c *cachedAccessor[K, V]) waitForResult(ctx context.Context, result *accessResult[V]) (V, error) {
+    err := result.sem.Acquire(ctx, 1)
+    if err != nil {
+        var zeroValue V
+        return zeroValue, err
+    }
+
+    result.sem.Release(1)
+    return result.value, result.err
+}
+
+// fetchResult fetches the value for the given key and returns it. If the context is cancelled before the value
+// is fetched, the function will return early. If the fetch is successful, the value will be added to the cache.
+func (c *cachedAccessor[K, V]) fetchResult(ctx context.Context, key K, result *accessResult[V]) (V, error) {
+
+    // Perform the work in a background goroutine. This allows us to return early if the context is cancelled
+    // without disrupting the fetch operation that other requesters may be waiting for.
+    waitChan := make(chan struct{}, 1)
+    go func() {
         if c.concurrencyLimiter != nil {
             c.concurrencyLimiter <- struct{}{}
         }
@@ -139,13 +165,22 @@ func (c *cachedAccessor[K, V]) Get(key K) (V, error) {
         // Provide the result to all other goroutines that may be waiting for it.
         result.err = err
         result.value = value
-        result.wg.Done()
+        result.sem.Release(1)
 
         // Clean up the lookupInProgress map.
         delete(c.lookupsInProgress, key)
         c.cacheLock.Unlock()
 
-        return value, err
+        waitChan <- struct{}{}
+    }()
+
+    select {
+    case <-ctx.Done():
+        // The context was cancelled before the value was fetched, possibly due to a timeout.
+        var zeroValue V
+        return zeroValue, ctx.Err()
+    case <-waitChan:
+        return result.value, result.err
     }
 }
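In the accessor above, the WaitGroup that used to act as a completion latch is replaced with a weighted semaphore so that a waiter can give up when its context is cancelled while the shared fetch, and every other waiter, carries on. A usage sketch under those semantics (illustrative values; NewCachedAccessor and Get are the signatures from this diff, the 50ms deadline and key are assumptions):

func exampleCachedGet() {
    accessor := func(key string) (*string, error) {
        v := "value for " + key // stand-in for an expensive lookup
        return &v, nil
    }

    ca, err := cache.NewCachedAccessor[string, *string](1024, 8, accessor)
    if err != nil {
        panic(err)
    }

    // This caller is only willing to wait 50ms for the shared fetch.
    ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
    defer cancel()

    if value, err := ca.Get(ctx, "foo"); err == nil {
        _ = *value
    }
    // If this caller times out, other goroutines waiting on "foo" with their own
    // contexts still receive the value once the single underlying fetch completes.
}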
diff --git a/relay/cache/cached_accessor_test.go b/relay/cache/cached_accessor_test.go
index ab37fa5a2e..9048e3d88a 100644
--- a/relay/cache/cached_accessor_test.go
+++ b/relay/cache/cached_accessor_test.go
@@ -1,6 +1,7 @@
 package cache
 
 import (
+    "context"
     "errors"
     tu "github.com/Layr-Labs/eigenda/common/testutils"
     "github.com/stretchr/testify/require"
@@ -36,7 +37,7 @@ func TestRandomOperationsSingleThread(t *testing.T) {
     require.NoError(t, err)
 
     for i := 0; i < dataSize; i++ {
-        value, err := ca.Get(i)
+        value, err := ca.Get(context.Background(), i)
 
         if i%17 == 0 {
             require.Error(t, err)
@@ -48,7 +49,7 @@ func TestRandomOperationsSingleThread(t *testing.T) {
     }
 
     for k, v := range baseData {
-        value, err := ca.Get(k)
+        value, err := ca.Get(context.Background(), k)
 
         if k%17 == 0 {
             require.Error(t, err)
@@ -86,7 +87,7 @@ func TestCacheMisses(t *testing.T) {
     expectedCacheMissCount := uint64(0)
     for i := 0; i < cacheSize; i++ {
         expectedCacheMissCount++
-        value, err := ca.Get(i)
+        value, err := ca.Get(context.Background(), i)
         require.NoError(t, err)
         require.Equal(t, baseData[i], *value)
         require.Equal(t, expectedCacheMissCount, cacheMissCount.Load())
@@ -94,7 +95,7 @@
 
     // Get the first cacheSize keys again. This should not increase the cache miss count.
     for i := 0; i < cacheSize; i++ {
-        value, err := ca.Get(i)
+        value, err := ca.Get(context.Background(), i)
         require.NoError(t, err)
         require.Equal(t, baseData[i], *value)
         require.Equal(t, expectedCacheMissCount, cacheMissCount.Load())
@@ -102,14 +103,14 @@
 
     // Read the last key. This should cause the first key to be evicted.
     expectedCacheMissCount++
-    value, err := ca.Get(cacheSize)
+    value, err := ca.Get(context.Background(), cacheSize)
     require.NoError(t, err)
     require.Equal(t, baseData[cacheSize], *value)
 
     // Read the keys in order. Due to the order of evictions, each read should result in a cache miss.
     for i := 0; i < cacheSize; i++ {
         expectedCacheMissCount++
-        value, err := ca.Get(i)
+        value, err := ca.Get(context.Background(), i)
         require.NoError(t, err)
         require.Equal(t, baseData[i], *value)
         require.Equal(t, expectedCacheMissCount, cacheMissCount.Load())
@@ -154,7 +155,7 @@ func ParallelAccessTest(t *testing.T, sleepEnabled bool) {
     for i := 0; i < 10; i++ {
         go func() {
             defer wg.Done()
-            value, err := ca.Get(0)
+            value, err := ca.Get(context.Background(), 0)
             require.NoError(t, err)
             require.Equal(t, baseData[0], *value)
         }()
@@ -177,7 +178,7 @@
     require.Equal(t, uint64(1), cacheMissCount.Load())
 
     // Fetching the key again should not result in a cache miss.
-    value, err := ca.Get(0)
+    value, err := ca.Get(context.Background(), 0)
     require.NoError(t, err)
     require.Equal(t, baseData[0], *value)
     require.Equal(t, uint64(1), cacheMissCount.Load())
@@ -223,7 +224,7 @@ func TestParallelAccessWithError(t *testing.T) {
     for i := 0; i < 10; i++ {
         go func() {
             defer wg.Done()
-            value, err := ca.Get(0)
+            value, err := ca.Get(context.Background(), 0)
             require.Nil(t, value)
             require.Equal(t, errors.New("intentional error"), err)
         }()
@@ -246,7 +247,7 @@
     require.True(t, count >= 1)
 
     // Fetching the key again should result in another cache miss since the previous fetch failed.
-    value, err := ca.Get(0)
+    value, err := ca.Get(context.Background(), 0)
     require.Nil(t, value)
     require.Equal(t, errors.New("intentional error"), err)
     require.Equal(t, count+1, cacheMissCount.Load())
@@ -291,7 +292,7 @@ func TestConcurrencyLimiter(t *testing.T) {
     for i := 0; i < dataSize; i++ {
         boundI := i
         go func() {
-            value, err := ca.Get(boundI)
+            value, err := ca.Get(context.Background(), boundI)
             require.NoError(t, err)
             require.Equal(t, baseData[boundI], *value)
             wg.Done()
@@ -310,3 +311,183 @@
     accessorLock.Unlock()
     wg.Wait()
 }
+
+func TestOriginalRequesterTimesOut(t *testing.T) {
+    tu.InitializeRandom()
+
+    dataSize := 1024
+
+    baseData := make(map[int]string)
+    for i := 0; i < dataSize; i++ {
+        baseData[i] = tu.RandomString(10)
+    }
+
+    accessorLock := sync.RWMutex{}
+    cacheMissCount := atomic.Uint64{}
+    accessor := func(key int) (*string, error) {
+
+        // Intentionally block if accessorLock is held by the outside scope.
+        // Used to provoke specific race conditions.
+        accessorLock.Lock()
+        defer accessorLock.Unlock()
+
+        cacheMissCount.Add(1)
+
+        str := baseData[key]
+        return &str, nil
+    }
+    cacheSize := rand.Intn(dataSize) + 1
+
+    ca, err := NewCachedAccessor(cacheSize, 0, accessor)
+    require.NoError(t, err)
+
+    // Lock the accessor. This will cause all cache misses to block.
+    accessorLock.Lock()
+
+    // Start several goroutines that will attempt to access the same key.
+    wg := sync.WaitGroup{}
+    wg.Add(10)
+    errCount := atomic.Uint64{}
+    for i := 0; i < 10; i++ {
+
+        var ctx context.Context
+        if i == 0 {
+            var cancel context.CancelFunc
+            ctx, cancel = context.WithTimeout(context.Background(), 1*time.Millisecond)
+            defer cancel()
+        } else {
+            ctx = context.Background()
+        }
+
+        go func() {
+            defer wg.Done()
+            value, err := ca.Get(ctx, 0)
+
+            if err != nil {
+                errCount.Add(1)
+            } else {
+                require.Equal(t, baseData[0], *value)
+            }
+        }()
+
+        if i == 0 {
+            // Give the thread with the small timeout a chance to start. Although this sleep statement is
+            // not required for the test to pass, it makes it much more likely for this test to exercise
+            // the intended code pathway.
+            time.Sleep(100 * time.Millisecond)
+        }
+    }
+
+    // Unlock the accessor. This will allow the goroutines to proceed.
+    accessorLock.Unlock()
+
+    // Wait for the goroutines to finish.
+    wg.Wait()
+
+    // Only one of the goroutines should have called into the accessor.
+    require.Equal(t, uint64(1), cacheMissCount.Load())
+
+    // At most, one goroutine should have timed out.
+    require.True(t, errCount.Load() <= 1)
+
+    // Fetching the key again should not result in a cache miss.
+    value, err := ca.Get(context.Background(), 0)
+    require.NoError(t, err)
+    require.Equal(t, baseData[0], *value)
+    require.Equal(t, uint64(1), cacheMissCount.Load())
+
+    // The internal lookupsInProgress map should no longer contain the key.
+    require.Equal(t, 0, len(ca.(*cachedAccessor[int, *string]).lookupsInProgress))
+}
+
+func TestSecondaryRequesterTimesOut(t *testing.T) {
+    tu.InitializeRandom()
+
+    dataSize := 1024
+
+    baseData := make(map[int]string)
+    for i := 0; i < dataSize; i++ {
+        baseData[i] = tu.RandomString(10)
+    }
+
+    accessorLock := sync.RWMutex{}
+    cacheMissCount := atomic.Uint64{}
+    accessor := func(key int) (*string, error) {
+
+        // Intentionally block if accessorLock is held by the outside scope.
+        // Used to provoke specific race conditions.
+        accessorLock.Lock()
+        defer accessorLock.Unlock()
+
+        cacheMissCount.Add(1)
+
+        str := baseData[key]
+        return &str, nil
+    }
+    cacheSize := rand.Intn(dataSize) + 1
+
+    ca, err := NewCachedAccessor(cacheSize, 0, accessor)
+    require.NoError(t, err)
+
+    // Lock the accessor. This will cause all cache misses to block.
+    accessorLock.Lock()
+
+    // Start several goroutines that will attempt to access the same key.
+    wg := sync.WaitGroup{}
+    wg.Add(10)
+    errCount := atomic.Uint64{}
+    for i := 0; i < 10; i++ {
+
+        var ctx context.Context
+        if i == 1 {
+            var cancel context.CancelFunc
+            ctx, cancel = context.WithTimeout(context.Background(), 1*time.Millisecond)
+            defer cancel()
+        } else {
+            ctx = context.Background()
+        }
+
+        go func() {
+            defer wg.Done()
+            value, err := ca.Get(ctx, 0)
+
+            if err != nil {
+                errCount.Add(1)
+            } else {
+                require.Equal(t, baseData[0], *value)
+            }
+        }()
+
+        if i == 0 {
+            // Give the thread with the context that won't time out a chance to start. Although this sleep statement is
+            // not required for the test to pass, it makes it much more likely for this test to exercise
+            // the intended code pathway.
+            time.Sleep(100 * time.Millisecond)
+        }
+    }
+
+    // Give context a chance to time out. Although this sleep statement is not required for the test to pass, it makes
+    // it much more likely for this test to exercise the intended code pathway.
+    time.Sleep(100 * time.Millisecond)
+
+    // Unlock the accessor. This will allow the goroutines to proceed.
+    accessorLock.Unlock()
+
+    // Wait for the goroutines to finish.
+    wg.Wait()
+
+    // Only one of the goroutines should have called into the accessor.
+    require.Equal(t, uint64(1), cacheMissCount.Load())
+
+    // At most, one goroutine should have timed out.
+    require.True(t, errCount.Load() <= 1)
+
+    // Fetching the key again should not result in a cache miss.
+    value, err := ca.Get(context.Background(), 0)
+    require.NoError(t, err)
+    require.Equal(t, baseData[0], *value)
+    require.Equal(t, uint64(1), cacheMissCount.Load())
+
+    // The internal lookupsInProgress map should no longer contain the key.
+    require.Equal(t, 0, len(ca.(*cachedAccessor[int, *string]).lookupsInProgress))
+}
diff --git a/relay/chunk_provider.go b/relay/chunk_provider.go
index 3fffb42a3b..48ece7c3cd 100644
--- a/relay/chunk_provider.go
+++ b/relay/chunk_provider.go
@@ -11,6 +11,7 @@ import (
     "github.com/Layr-Labs/eigenda/relay/chunkstore"
     "github.com/Layr-Labs/eigensdk-go/logging"
     "sync"
+    "time"
 )
 
 type chunkProvider struct {
@@ -23,6 +24,12 @@ type chunkProvider struct {
 
     // chunkReader is used to read chunks from the chunk store.
     chunkReader chunkstore.ChunkReader
+
+    // fetchTimeout is the maximum time to wait for a chunk proof fetch operation to complete.
+    proofFetchTimeout time.Duration
+
+    // coefficientFetchTimeout is the maximum time to wait for a chunk coefficient fetch operation to complete.
+    coefficientFetchTimeout time.Duration
 }
 
 // blobKeyWithMetadata attaches some additional metadata to a blobKey.
@@ -41,12 +48,16 @@ func newChunkProvider(
     logger logging.Logger,
     chunkReader chunkstore.ChunkReader,
     cacheSize int,
-    maxIOConcurrency int) (*chunkProvider, error) {
+    maxIOConcurrency int,
+    proofFetchTimeout time.Duration,
+    coefficientFetchTimeout time.Duration) (*chunkProvider, error) {
 
     server := &chunkProvider{
-        ctx:         ctx,
-        logger:      logger,
-        chunkReader: chunkReader,
+        ctx:                     ctx,
+        logger:                  logger,
+        chunkReader:             chunkReader,
+        proofFetchTimeout:       proofFetchTimeout,
+        coefficientFetchTimeout: coefficientFetchTimeout,
     }
 
     c, err := cache.NewCachedAccessor[blobKeyWithMetadata, []*encoding.Frame](
@@ -89,7 +100,7 @@ func (s *chunkProvider) GetFrames(ctx context.Context, mMap metadataMap) (frameM
         boundKey := key
         go func() {
-            frames, err := s.frameCache.Get(*boundKey)
+            frames, err := s.frameCache.Get(ctx, *boundKey)
             if err != nil {
                 s.logger.Errorf("Failed to get frames for blob %v: %v", boundKey.blobKey, err)
                 completionChannel <- &framesResult{
@@ -128,10 +139,13 @@ func (s *chunkProvider) fetchFrames(key blobKeyWithMetadata) ([]*encoding.Frame,
     var proofsErr error
 
     go func() {
+        ctx, cancel := context.WithTimeout(s.ctx, s.proofFetchTimeout)
         defer func() {
             wg.Done()
+            cancel()
         }()
-        proofs, proofsErr = s.chunkReader.GetChunkProofs(s.ctx, key.blobKey)
+
+        proofs, proofsErr = s.chunkReader.GetChunkProofs(ctx, key.blobKey)
     }()
 
     fragmentInfo := &encoding.FragmentInfo{
@@ -139,7 +153,10 @@ func (s *chunkProvider) fetchFrames(key blobKeyWithMetadata) ([]*encoding.Frame,
         FragmentSizeBytes: key.metadata.fragmentSizeBytes,
     }
 
-    coefficients, err := s.chunkReader.GetChunkCoefficients(s.ctx, key.blobKey, fragmentInfo)
+    ctx, cancel := context.WithTimeout(s.ctx, s.coefficientFetchTimeout)
+    defer cancel()
+
+    coefficients, err := s.chunkReader.GetChunkCoefficients(ctx, key.blobKey, fragmentInfo)
     if err != nil {
         return nil, err
     }
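fetchFrames above now bounds its two store reads independently: proofs are fetched on a goroutine under proofFetchTimeout while coefficients are fetched on the calling goroutine under coefficientFetchTimeout, both derived from the provider's base context. The general shape, reduced to a sketch with assumed names (not code from this diff):

import (
    "context"
    "sync"
    "time"
)

func fetchInParallel(baseCtx context.Context, proofTimeout, coeffTimeout time.Duration,
    getProofs, getCoefficients func(ctx context.Context) error) error {

    var wg sync.WaitGroup
    wg.Add(1)

    var proofsErr error
    go func() {
        ctx, cancel := context.WithTimeout(baseCtx, proofTimeout)
        defer func() {
            wg.Done()
            cancel()
        }()
        proofsErr = getProofs(ctx)
    }()

    ctx, cancel := context.WithTimeout(baseCtx, coeffTimeout)
    defer cancel()
    coeffErr := getCoefficients(ctx)

    wg.Wait()
    if proofsErr != nil {
        return proofsErr
    }
    return coeffErr
}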
"github.com/stretchr/testify/require" "testing" + "time" ) func TestFetchingIndividualBlobs(t *testing.T) { @@ -44,7 +45,14 @@ func TestFetchingIndividualBlobs(t *testing.T) { fragmentInfoMap[blobKey] = fragmentInfo } - server, err := newChunkProvider(context.Background(), logger, chunkReader, 10, 32) + server, err := newChunkProvider( + context.Background(), + logger, + chunkReader, + 10, + 32, + 10*time.Second, + 10*time.Second) require.NoError(t, err) // Read it back. @@ -124,7 +132,14 @@ func TestFetchingBatchedBlobs(t *testing.T) { fragmentInfoMap[blobKey] = fragmentInfo } - server, err := newChunkProvider(context.Background(), logger, chunkReader, 10, 32) + server, err := newChunkProvider( + context.Background(), + logger, + chunkReader, + 10, + 32, + 10*time.Second, + 10*time.Second) require.NoError(t, err) // Read it back. diff --git a/relay/cmd/config.go b/relay/cmd/config.go index 13f56e7ada..154c4c2bd2 100644 --- a/relay/cmd/config.go +++ b/relay/cmd/config.go @@ -85,6 +85,14 @@ func NewConfig(ctx *cli.Context) (Config, error) { AuthenticationTimeout: ctx.Duration(flags.AuthenticationTimeoutFlag.Name), AuthenticationDisabled: ctx.Bool(flags.AuthenticationDisabledFlag.Name), OnchainStateRefreshInterval: ctx.Duration(flags.OnchainStateRefreshIntervalFlag.Name), + Timeouts: relay.TimeoutConfig{ + GetChunksTimeout: ctx.Duration(flags.GetChunksTimeoutFlag.Name), + GetBlobTimeout: ctx.Duration(flags.GetBlobTimeoutFlag.Name), + InternalGetMetadataTimeout: ctx.Duration(flags.InternalGetMetadataTimeoutFlag.Name), + InternalGetBlobTimeout: ctx.Duration(flags.InternalGetBlobTimeoutFlag.Name), + InternalGetProofsTimeout: ctx.Duration(flags.InternalGetProofsTimeoutFlag.Name), + InternalGetCoefficientsTimeout: ctx.Duration(flags.InternalGetCoefficientsTimeoutFlag.Name), + }, }, EthClientConfig: geth.ReadEthClientConfig(ctx), BLSOperatorStateRetrieverAddr: ctx.String(flags.BlsOperatorStateRetrieverAddrFlag.Name), diff --git a/relay/cmd/flags/flags.go b/relay/cmd/flags/flags.go index f471765966..baed1fbcf4 100644 --- a/relay/cmd/flags/flags.go +++ b/relay/cmd/flags/flags.go @@ -231,6 +231,48 @@ var ( Required: false, EnvVar: common.PrefixEnvVar(envVarPrefix, "AUTHENTICATION_DISABLED"), } + GetChunksTimeoutFlag = cli.DurationFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-chunks-timeout"), + Usage: "Timeout for GetChunks()", + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_CHUNKS_TIMEOUT"), + Required: false, + Value: 20 * time.Second, + } + GetBlobTimeoutFlag = cli.DurationFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-blob-timeout"), + Usage: "Timeout for GetBlob()", + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_BLOB_TIMEOUT"), + Required: false, + Value: 20 * time.Second, + } + InternalGetMetadataTimeoutFlag = cli.DurationFlag{ + Name: common.PrefixFlag(FlagPrefix, "internal-get-metadata-timeout"), + Usage: "Timeout for internal metadata fetch", + EnvVar: common.PrefixEnvVar(envVarPrefix, "INTERNAL_GET_METADATA_TIMEOUT"), + Required: false, + Value: 5 * time.Second, + } + InternalGetBlobTimeoutFlag = cli.DurationFlag{ + Name: common.PrefixFlag(FlagPrefix, "internal-get-blob-timeout"), + Usage: "Timeout for internal blob fetch", + EnvVar: common.PrefixEnvVar(envVarPrefix, "INTERNAL_GET_BLOB_TIMEOUT"), + Required: false, + Value: 20 * time.Second, + } + InternalGetProofsTimeoutFlag = cli.DurationFlag{ + Name: common.PrefixFlag(FlagPrefix, "internal-get-proofs-timeout"), + Usage: "Timeout for internal proofs fetch", + EnvVar: common.PrefixEnvVar(envVarPrefix, 
"INTERNAL_GET_PROOFS_TIMEOUT"), + Required: false, + Value: 5 * time.Second, + } + InternalGetCoefficientsTimeoutFlag = cli.DurationFlag{ + Name: common.PrefixFlag(FlagPrefix, "internal-get-coefficients-timeout"), + Usage: "Timeout for internal coefficients fetch", + EnvVar: common.PrefixEnvVar(envVarPrefix, "INTERNAL_GET_COEFFICIENTS_TIMEOUT"), + Required: false, + Value: 20 * time.Second, + } OnchainStateRefreshIntervalFlag = cli.DurationFlag{ Name: common.PrefixFlag(FlagPrefix, "onchain-state-refresh-interval"), Usage: "The interval at which to refresh the onchain state", @@ -247,6 +289,8 @@ var requiredFlags = []cli.Flag{ RelayIDsFlag, BlsOperatorStateRetrieverAddrFlag, EigenDAServiceManagerAddrFlag, + AuthenticationTimeoutFlag, + AuthenticationDisabledFlag, } var optionalFlags = []cli.Flag{ @@ -276,6 +320,12 @@ var optionalFlags = []cli.Flag{ AuthenticationKeyCacheSizeFlag, AuthenticationTimeoutFlag, AuthenticationDisabledFlag, + GetChunksTimeoutFlag, + GetBlobTimeoutFlag, + InternalGetMetadataTimeoutFlag, + InternalGetBlobTimeoutFlag, + InternalGetProofsTimeoutFlag, + InternalGetCoefficientsTimeoutFlag, OnchainStateRefreshIntervalFlag, } diff --git a/relay/metadata_provider.go b/relay/metadata_provider.go index f86bccfddd..8f3f43ed86 100644 --- a/relay/metadata_provider.go +++ b/relay/metadata_provider.go @@ -10,6 +10,7 @@ import ( "github.com/Layr-Labs/eigenda/encoding" "github.com/Layr-Labs/eigenda/relay/cache" "github.com/Layr-Labs/eigensdk-go/logging" + "time" ) // Metadata about a blob. The relay only needs a small subset of a blob's metadata. @@ -41,6 +42,9 @@ type metadataProvider struct { // that are not assigned to one of these IDs. relayIDSet map[v2.RelayKey]struct{} + // fetchTimeout is the maximum time to wait for a metadata fetch operation to complete. + fetchTimeout time.Duration + // blobParamsMap is a map of blob version to blob version parameters. blobParamsMap atomic.Pointer[v2.BlobVersionParameterMap] } @@ -53,8 +57,8 @@ func newMetadataProvider( metadataCacheSize int, maxIOConcurrency int, relayIDs []v2.RelayKey, - blobParamsMap *v2.BlobVersionParameterMap, -) (*metadataProvider, error) { + fetchTimeout time.Duration, + blobParamsMap *v2.BlobVersionParameterMap) (*metadataProvider, error) { relayIDSet := make(map[v2.RelayKey]struct{}, len(relayIDs)) for _, id := range relayIDs { @@ -66,6 +70,7 @@ func newMetadataProvider( logger: logger, metadataStore: metadataStore, relayIDSet: relayIDSet, + fetchTimeout: fetchTimeout, } server.blobParamsMap.Store(blobParamsMap) @@ -89,7 +94,8 @@ type metadataMap map[v2.BlobKey]*blobMetadata // If any of the blobs do not exist, an error is returned. // Note that resulting metadata map may not have the same length as the input // keys slice if the input keys slice has duplicate items. -func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap, error) { +func (m *metadataProvider) GetMetadataForBlobs(ctx context.Context, keys []v2.BlobKey) (metadataMap, error) { + // blobMetadataResult is the result of a metadata fetch operation. type blobMetadataResult struct { key v2.BlobKey @@ -116,7 +122,7 @@ func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap, boundKey := key go func() { - metadata, err := m.metadataCache.Get(boundKey) + metadata, err := m.metadataCache.Get(ctx, boundKey) if err != nil { // Intentionally log at debug level. 
diff --git a/relay/metadata_provider.go b/relay/metadata_provider.go
index f86bccfddd..8f3f43ed86 100644
--- a/relay/metadata_provider.go
+++ b/relay/metadata_provider.go
@@ -10,6 +10,7 @@ import (
     "github.com/Layr-Labs/eigenda/encoding"
     "github.com/Layr-Labs/eigenda/relay/cache"
     "github.com/Layr-Labs/eigensdk-go/logging"
+    "time"
 )
 
 // Metadata about a blob. The relay only needs a small subset of a blob's metadata.
@@ -41,6 +42,9 @@ type metadataProvider struct {
     // that are not assigned to one of these IDs.
     relayIDSet map[v2.RelayKey]struct{}
 
+    // fetchTimeout is the maximum time to wait for a metadata fetch operation to complete.
+    fetchTimeout time.Duration
+
     // blobParamsMap is a map of blob version to blob version parameters.
     blobParamsMap atomic.Pointer[v2.BlobVersionParameterMap]
 }
@@ -53,8 +57,8 @@ func newMetadataProvider(
     metadataCacheSize int,
     maxIOConcurrency int,
     relayIDs []v2.RelayKey,
-    blobParamsMap *v2.BlobVersionParameterMap,
-) (*metadataProvider, error) {
+    fetchTimeout time.Duration,
+    blobParamsMap *v2.BlobVersionParameterMap) (*metadataProvider, error) {
 
     relayIDSet := make(map[v2.RelayKey]struct{}, len(relayIDs))
     for _, id := range relayIDs {
@@ -66,6 +70,7 @@ func newMetadataProvider(
         logger: logger,
         metadataStore: metadataStore,
         relayIDSet: relayIDSet,
+        fetchTimeout: fetchTimeout,
     }
     server.blobParamsMap.Store(blobParamsMap)
 
@@ -89,7 +94,8 @@ type metadataMap map[v2.BlobKey]*blobMetadata
 // If any of the blobs do not exist, an error is returned.
 // Note that resulting metadata map may not have the same length as the input
 // keys slice if the input keys slice has duplicate items.
-func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap, error) {
+func (m *metadataProvider) GetMetadataForBlobs(ctx context.Context, keys []v2.BlobKey) (metadataMap, error) {
+
     // blobMetadataResult is the result of a metadata fetch operation.
     type blobMetadataResult struct {
         key v2.BlobKey
@@ -116,7 +122,7 @@ func (m *metadataProvider) GetMetadataForBlobs(keys []v2.BlobKey) (metadataMap,
         boundKey := key
         go func() {
-            metadata, err := m.metadataCache.Get(boundKey)
+            metadata, err := m.metadataCache.Get(ctx, boundKey)
             if err != nil {
                 // Intentionally log at debug level. External users can force this condition to trigger
                 // by requesting metadata for a blob that does not exist, and so it's important to avoid
@@ -153,13 +159,16 @@ func (m *metadataProvider) UpdateBlobVersionParameters(blobParamsMap *v2.BlobVer
 
 // fetchMetadata retrieves metadata about a blob. Fetches from the cache if available, otherwise from the store.
 func (m *metadataProvider) fetchMetadata(key v2.BlobKey) (*blobMetadata, error) {
+    ctx, cancel := context.WithTimeout(m.ctx, m.fetchTimeout)
+    defer cancel()
+
     blobParamsMap := m.blobParamsMap.Load()
     if blobParamsMap == nil {
         return nil, fmt.Errorf("blob version parameters is nil")
     }
 
     // Retrieve the metadata from the store.
-    cert, fragmentInfo, err := m.metadataStore.GetBlobCertificate(m.ctx, key)
+    cert, fragmentInfo, err := m.metadataStore.GetBlobCertificate(ctx, key)
     if err != nil {
         return nil, fmt.Errorf("error retrieving metadata for blob %s: %w", key.Hex(), err)
     }
diff --git a/relay/metadata_provider_test.go b/relay/metadata_provider_test.go
index 228ec83264..b48e157ec0 100644
--- a/relay/metadata_provider_test.go
+++ b/relay/metadata_provider_test.go
@@ -11,6 +11,7 @@ import (
     "github.com/Layr-Labs/eigenda/encoding"
     "github.com/stretchr/testify/assert"
     "github.com/stretchr/testify/require"
+    "time"
 )
 
 func TestGetNonExistentBlob(t *testing.T) {
@@ -23,12 +24,20 @@ func TestGetNonExistentBlob(t *testing.T) {
     defer teardown()
     metadataStore := buildMetadataStore(t)
 
-    server, err := newMetadataProvider(context.Background(), logger, metadataStore, 1024*1024, 32, nil, v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
+    server, err := newMetadataProvider(
+        context.Background(),
+        logger,
+        metadataStore,
+        1024*1024,
+        32,
+        nil,
+        10*time.Second,
+        v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
     require.NoError(t, err)
 
     // Try to fetch a non-existent blobs
     for i := 0; i < 10; i++ {
-        _, err := server.GetMetadataForBlobs([]v2.BlobKey{v2.BlobKey(tu.RandomBytes(32))})
+        _, err := server.GetMetadataForBlobs(context.Background(), []v2.BlobKey{v2.BlobKey(tu.RandomBytes(32))})
         require.Error(t, err)
     }
 }
@@ -81,12 +90,21 @@ func TestFetchingIndividualMetadata(t *testing.T) {
         require.Equal(t, fragmentSizeMap[blobKey], fragmentInfo.FragmentSizeBytes)
     }
 
-    server, err := newMetadataProvider(context.Background(), logger, metadataStore, 1024*1024, 32, nil, v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
+    server, err := newMetadataProvider(
+        context.Background(),
+        logger,
+        metadataStore,
+        1024*1024,
+        32,
+        nil,
+        10*time.Second,
+        v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
+
     require.NoError(t, err)
 
     // Fetch the metadata from the server.
     for blobKey, totalChunkSizeBytes := range totalChunkSizeMap {
-        mMap, err := server.GetMetadataForBlobs([]v2.BlobKey{blobKey})
+        mMap, err := server.GetMetadataForBlobs(context.Background(), []v2.BlobKey{blobKey})
         require.NoError(t, err)
         require.Equal(t, 1, len(mMap))
         metadata := mMap[blobKey]
@@ -97,7 +115,7 @@ func TestFetchingIndividualMetadata(t *testing.T) {
 
     // Read it back again. This uses a different code pathway due to the cache.
     for blobKey, totalChunkSizeBytes := range totalChunkSizeMap {
-        mMap, err := server.GetMetadataForBlobs([]v2.BlobKey{blobKey})
+        mMap, err := server.GetMetadataForBlobs(context.Background(), []v2.BlobKey{blobKey})
         require.NoError(t, err)
         require.Equal(t, 1, len(mMap))
         metadata := mMap[blobKey]
@@ -157,7 +175,15 @@ func TestBatchedFetch(t *testing.T) {
         require.Equal(t, fragmentSizeMap[blobKey], fragmentInfo.FragmentSizeBytes)
     }
 
-    server, err := newMetadataProvider(context.Background(), logger, metadataStore, 1024*1024, 32, nil, v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
+    server, err := newMetadataProvider(
+        context.Background(),
+        logger,
+        metadataStore,
+        1024*1024,
+        32,
+        nil,
+        10*time.Second,
+        v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
     require.NoError(t, err)
 
     // Each iteration, choose a random subset of the keys to fetch
@@ -171,7 +197,7 @@ func TestBatchedFetch(t *testing.T) {
             }
         }
 
-        mMap, err := server.GetMetadataForBlobs(keys)
+        mMap, err := server.GetMetadataForBlobs(context.Background(), keys)
         require.NoError(t, err)
         assert.Equal(t, keyCount, len(mMap))
 
@@ -184,7 +210,7 @@ func TestBatchedFetch(t *testing.T) {
     }
 
     // Test fetching with duplicate keys
-    mMap, err := server.GetMetadataForBlobs([]v2.BlobKey{blobKeys[0], blobKeys[0]})
+    mMap, err := server.GetMetadataForBlobs(context.Background(), []v2.BlobKey{blobKeys[0], blobKeys[0]})
     require.NoError(t, err)
     require.Equal(t, 1, len(mMap))
 }
@@ -255,7 +281,15 @@ func TestIndividualFetchWithSharding(t *testing.T) {
         require.Equal(t, fragmentSizeMap[blobKey], fragmentInfo.FragmentSizeBytes)
     }
 
-    server, err := newMetadataProvider(context.Background(), logger, metadataStore, 1024*1024, 32, shardList, v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
+    server, err := newMetadataProvider(
+        context.Background(),
+        logger,
+        metadataStore,
+        1024*1024,
+        32,
+        shardList,
+        10*time.Second,
+        v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
     require.NoError(t, err)
 
     // Fetch the metadata from the server.
@@ -269,7 +303,7 @@ func TestIndividualFetchWithSharding(t *testing.T) {
             }
         }
 
-        mMap, err := server.GetMetadataForBlobs([]v2.BlobKey{blobKey})
+        mMap, err := server.GetMetadataForBlobs(context.Background(), []v2.BlobKey{blobKey})
 
         if isBlobInCorrectShard {
             // The blob is in the relay's shard, should be returned like normal
@@ -296,7 +330,7 @@ func TestIndividualFetchWithSharding(t *testing.T) {
             }
         }
 
-        mMap, err := server.GetMetadataForBlobs([]v2.BlobKey{blobKey})
+        mMap, err := server.GetMetadataForBlobs(context.Background(), []v2.BlobKey{blobKey})
 
         if isBlobInCorrectShard {
             // The blob is in the relay's shard, should be returned like normal
@@ -379,7 +413,15 @@ func TestBatchedFetchWithSharding(t *testing.T) {
         require.Equal(t, fragmentSizeMap[blobKey], fragmentInfo.FragmentSizeBytes)
     }
 
-    server, err := newMetadataProvider(context.Background(), logger, metadataStore, 1024*1024, 32, shardList, v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
+    server, err := newMetadataProvider(
+        context.Background(),
+        logger,
+        metadataStore,
+        1024*1024,
+        32,
+        shardList,
+        10*time.Second,
+        v2.NewBlobVersionParameterMap(mockBlobParamsMap()))
     require.NoError(t, err)
 
     // Each iteration, choose two random keys to fetch. There will be a 25% chance that both blobs map to valid shards.
@@ -409,7 +451,7 @@ func TestBatchedFetchWithSharding(t *testing.T) {
             }
         }
 
-        mMap, err := server.GetMetadataForBlobs(keys)
+        mMap, err := server.GetMetadataForBlobs(context.Background(), keys)
         if areKeysInCorrectShard {
             require.NoError(t, err)
             assert.Equal(t, keyCount, len(mMap))
diff --git a/relay/relay_test_utils.go b/relay/relay_test_utils.go
index 2a075a8294..1e5120f700 100644
--- a/relay/relay_test_utils.go
+++ b/relay/relay_test_utils.go
@@ -3,15 +3,6 @@ package relay
 import (
     "context"
     "fmt"
-    "log"
-    "math/big"
-    "os"
-    "path/filepath"
-    "runtime"
-    "strings"
-    "testing"
-    "time"
-
     pbcommon "github.com/Layr-Labs/eigenda/api/grpc/common"
     pbcommonv2 "github.com/Layr-Labs/eigenda/api/grpc/common/v2"
     "github.com/Layr-Labs/eigenda/common"
@@ -36,6 +27,13 @@ import (
     "github.com/ory/dockertest/v3"
     "github.com/stretchr/testify/mock"
     "github.com/stretchr/testify/require"
+    "log"
+    "math/big"
+    "os"
+    "path/filepath"
+    "runtime"
+    "strings"
+    "testing"
 )
 
 var (
@@ -158,12 +156,10 @@ func buildBlobStore(t *testing.T, logger logging.Logger) *blobstore.BlobStore {
 
 func buildChunkStore(t *testing.T, logger logging.Logger) (chunkstore.ChunkReader, chunkstore.ChunkWriter) {
     cfg := aws.ClientConfig{
-        Region:               "us-east-1",
-        AccessKey:            "localstack",
-        SecretAccessKey:      "localstack",
-        EndpointURL:          localstackHost,
-        FragmentWriteTimeout: time.Duration(10) * time.Second,
-        FragmentReadTimeout:  time.Duration(10) * time.Second,
+        Region:          "us-east-1",
+        AccessKey:       "localstack",
+        SecretAccessKey: "localstack",
+        EndpointURL:     localstackHost,
     }
 
     client, err := s3.NewClient(context.Background(), cfg, logger)
@@ -199,7 +195,7 @@ func mockBlobParamsMap() map[uint8]*core.BlobVersionParameters {
 
 func randomBlob(t *testing.T) (*v2.BlobHeader, []byte) {
 
-    data := tu.RandomBytes(225) // TODO talk to Ian about this
+    data := tu.RandomBytes(225)
     data = codec.ConvertByPaddingEmptyByte(data)
 
     commitments, err := prover.GetCommitmentsForPaddedLength(data)
diff --git a/relay/server.go b/relay/server.go
index 32267de763..540b46b0b0 100644
--- a/relay/server.go
+++ b/relay/server.go
@@ -108,6 +108,9 @@ type Config struct {
     // AuthenticationDisabled will disable authentication if set to true.
     AuthenticationDisabled bool
 
+    // Timeouts contains configuration for relay timeouts.
+    Timeouts TimeoutConfig
+
     // OnchainStateRefreshInterval is the interval at which the onchain state is refreshed.
     OnchainStateRefreshInterval time.Duration
 }
@@ -139,8 +142,9 @@ func NewServer(
         config.MetadataCacheSize,
         config.MetadataMaxConcurrency,
         config.RelayIDs,
-        v2.NewBlobVersionParameterMap(blobParams),
-    )
+        config.Timeouts.InternalGetMetadataTimeout,
+        v2.NewBlobVersionParameterMap(blobParams))
+
     if err != nil {
         return nil, fmt.Errorf("error creating metadata provider: %w", err)
     }
@@ -150,7 +154,8 @@ func NewServer(
         logger,
         blobStore,
         config.BlobCacheSize,
-        config.BlobMaxConcurrency)
+        config.BlobMaxConcurrency,
+        config.Timeouts.InternalGetBlobTimeout)
     if err != nil {
         return nil, fmt.Errorf("error creating blob provider: %w", err)
     }
@@ -160,7 +165,9 @@ func NewServer(
         logger,
         chunkReader,
         config.ChunkCacheSize,
-        config.ChunkMaxConcurrency)
+        config.ChunkMaxConcurrency,
+        config.Timeouts.InternalGetProofsTimeout,
+        config.Timeouts.InternalGetCoefficientsTimeout)
     if err != nil {
         return nil, fmt.Errorf("error creating chunk provider: %w", err)
     }
@@ -168,6 +175,7 @@ func NewServer(
     var authenticator auth.RequestAuthenticator
     if !config.AuthenticationDisabled {
         authenticator, err = auth.NewRequestAuthenticator(
+            ctx,
             ics,
             config.AuthenticationKeyCacheSize,
             config.AuthenticationTimeout)
@@ -190,9 +198,11 @@ func NewServer(
 
 // GetBlob retrieves a blob stored by the relay.
 func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.GetBlobReply, error) {
-
-    // TODO(cody-littley):
-    //  - timeouts
+    if s.config.Timeouts.GetBlobTimeout > 0 {
+        var cancel context.CancelFunc
+        ctx, cancel = context.WithTimeout(ctx, s.config.Timeouts.GetBlobTimeout)
+        defer cancel()
+    }
 
     err := s.blobRateLimiter.BeginGetBlobOperation(time.Now())
     if err != nil {
@@ -206,7 +216,7 @@ func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.G
     }
 
     keys := []v2.BlobKey{key}
-    mMap, err := s.metadataProvider.GetMetadataForBlobs(keys)
+    mMap, err := s.metadataProvider.GetMetadataForBlobs(ctx, keys)
     if err != nil {
         return nil, fmt.Errorf(
             "error fetching metadata for blob, check if blob exists and is assigned to this relay: %w", err)
@@ -221,7 +231,7 @@ func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.G
         return nil, err
     }
 
-    data, err := s.blobProvider.GetBlob(key)
+    data, err := s.blobProvider.GetBlob(ctx, key)
     if err != nil {
         return nil, fmt.Errorf("error fetching blob %s: %w", key.Hex(), err)
     }
@@ -235,9 +245,11 @@ func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.G
 
 // GetChunks retrieves chunks from blobs stored by the relay.
 func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (*pb.GetChunksReply, error) {
-
-    // TODO(cody-littley):
-    //  - timeouts
+    if s.config.Timeouts.GetChunksTimeout > 0 {
+        var cancel context.CancelFunc
+        ctx, cancel = context.WithTimeout(ctx, s.config.Timeouts.GetChunksTimeout)
+        defer cancel()
+    }
 
     if len(request.ChunkRequests) <= 0 {
         return nil, fmt.Errorf("no chunk requests provided")
@@ -254,7 +266,7 @@ func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (*
     }
     clientAddress := client.Addr.String()
 
-    err := s.authenticator.AuthenticateGetChunksRequest(clientAddress, request, time.Now())
+    err := s.authenticator.AuthenticateGetChunksRequest(ctx, clientAddress, request, time.Now())
     if err != nil {
         return nil, fmt.Errorf("auth failed: %w", err)
     }
@@ -274,7 +286,7 @@ func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (*
         return nil, err
     }
 
-    mMap, err := s.metadataProvider.GetMetadataForBlobs(keys)
+    mMap, err := s.metadataProvider.GetMetadataForBlobs(ctx, keys)
     if err != nil {
         return nil, fmt.Errorf(
             "error fetching metadata for blob, check if blob exists and is assigned to this relay: %w", err)
diff --git a/relay/server_test.go b/relay/server_test.go
index baaf873150..3e16c624c3 100644
--- a/relay/server_test.go
+++ b/relay/server_test.go
@@ -4,6 +4,7 @@ import (
     "context"
     "math/rand"
     "testing"
+    "time"
 
     "github.com/Layr-Labs/eigenda/relay/limiter"
 
@@ -47,6 +48,14 @@ func defaultConfig() *Config {
             MaxConcurrentGetChunkOpsClient: 1,
         },
         AuthenticationDisabled: true,
+        Timeouts: TimeoutConfig{
+            GetBlobTimeout:                 10 * time.Second,
+            GetChunksTimeout:               10 * time.Second,
+            InternalGetMetadataTimeout:     10 * time.Second,
+            InternalGetBlobTimeout:         10 * time.Second,
+            InternalGetProofsTimeout:       10 * time.Second,
+            InternalGetCoefficientsTimeout: 10 * time.Second,
+        },
     }
 }
diff --git a/relay/timeout_config.go b/relay/timeout_config.go
new file mode 100644
index 0000000000..64c6be96ff
--- /dev/null
+++ b/relay/timeout_config.go
@@ -0,0 +1,26 @@
+package relay
+
+import "time"
+
+// TimeoutConfig encapsulates the timeout configuration for the relay server.
+type TimeoutConfig struct {
+
+    // The maximum time permitted for a GetChunks GRPC to complete. If zero then no timeout is enforced.
+    GetChunksTimeout time.Duration
+
+    // The maximum time permitted for a GetBlob GRPC to complete. If zero then no timeout is enforced.
+    GetBlobTimeout time.Duration
+
+    // The maximum time permitted for a single request to the metadata store to fetch the metadata
+    // for an individual blob.
+    InternalGetMetadataTimeout time.Duration
+
+    // The maximum time permitted for a single request to the blob store to fetch a blob.
+    InternalGetBlobTimeout time.Duration
+
+    // The maximum time permitted for a single request to the chunk store to fetch chunk proofs.
+    InternalGetProofsTimeout time.Duration
+
+    // The maximum time permitted for a single request to the chunk store to fetch chunk coefficients.
+    InternalGetCoefficientsTimeout time.Duration
+}
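GetBlob and GetChunks above both apply their timeout with the same guard: a zero value in TimeoutConfig leaves the caller's deadline untouched. That guard can be factored into a small helper; a sketch (the helper name is illustrative, not part of this diff):

func withOptionalTimeout(ctx context.Context, d time.Duration) (context.Context, context.CancelFunc) {
    if d > 0 {
        return context.WithTimeout(ctx, d)
    }
    return ctx, func() {} // no-op cancel keeps call sites uniform
}

A handler would then start with ctx, cancel := withOptionalTimeout(ctx, s.config.Timeouts.GetBlobTimeout) followed by defer cancel().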