diff --git a/relay/cmd/config.go b/relay/cmd/config.go index c7b8b46fcc..bb7566f5a1 100644 --- a/relay/cmd/config.go +++ b/relay/cmd/config.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "github.com/Layr-Labs/eigenda/relay/limiter" "github.com/Layr-Labs/eigenda/common" "github.com/Layr-Labs/eigenda/common/aws" @@ -12,21 +13,6 @@ import ( ) // Config is the configuration for the relay Server. -// -// Environment variables are mapped into this struct by taking the name of the field in this struct, -// converting to upper case, and prepending "RELAY_". For example, "BlobCacheSize" can be set using the -// environment variable "RELAY_BLOBCACHESIZE". -// -// For nested structs, add the name of the struct variable before the field name, separated by an underscore. -// For example, "Log.Format" can be set using the environment variable "RELAY_LOG_FORMAT". -// -// Slice values can be set using a comma-separated list. For example, "RelayIDs" can be set using the environment -// variable "RELAY_RELAYIDS='1,2,3,4'". -// -// It is also possible to set the configuration using a configuration file. The path to the configuration file should -// be passed as the first argument to the relay binary, e.g. "bin/relay config.yaml". The structure of the config -// file should mirror the structure of this struct, with keys in the config file matching the field names -// of this struct. type Config struct { // Log is the configuration for the logger. Default is common.DefaultLoggerConfig(). 
@@ -70,6 +56,23 @@ func NewConfig(ctx *cli.Context) (Config, error) { BlobMaxConcurrency: ctx.Int(flags.BlobMaxConcurrencyFlag.Name), ChunkCacheSize: ctx.Int(flags.ChunkCacheSizeFlag.Name), ChunkMaxConcurrency: ctx.Int(flags.ChunkMaxConcurrencyFlag.Name), + RateLimits: limiter.Config{ + MaxGetBlobOpsPerSecond: ctx.Float64(flags.MaxGetBlobOpsPerSecondFlag.Name), + GetBlobOpsBurstiness: ctx.Int(flags.GetBlobOpsBurstinessFlag.Name), + MaxGetBlobBytesPerSecond: ctx.Float64(flags.MaxGetBlobBytesPerSecondFlag.Name), + GetBlobBytesBurstiness: ctx.Int(flags.GetBlobBytesBurstinessFlag.Name), + MaxConcurrentGetBlobOps: ctx.Int(flags.MaxConcurrentGetBlobOpsFlag.Name), + MaxGetChunkOpsPerSecond: ctx.Float64(flags.MaxGetChunkOpsPerSecondFlag.Name), + GetChunkOpsBurstiness: ctx.Int(flags.GetChunkOpsBurstinessFlag.Name), + MaxGetChunkBytesPerSecond: ctx.Float64(flags.MaxGetChunkBytesPerSecondFlag.Name), + GetChunkBytesBurstiness: ctx.Int(flags.GetChunkBytesBurstinessFlag.Name), + MaxConcurrentGetChunkOps: ctx.Int(flags.MaxConcurrentGetChunkOpsFlag.Name), + MaxGetChunkOpsPerSecondClient: ctx.Float64(flags.MaxGetChunkOpsPerSecondClientFlag.Name), + GetChunkOpsBurstinessClient: ctx.Int(flags.GetChunkOpsBurstinessClientFlag.Name), + MaxGetChunkBytesPerSecondClient: ctx.Float64(flags.MaxGetChunkBytesPerSecondClientFlag.Name), + GetChunkBytesBurstinessClient: ctx.Int(flags.GetChunkBytesBurstinessClientFlag.Name), + MaxConcurrentGetChunkOpsClient: ctx.Int(flags.MaxConcurrentGetChunkOpsClientFlag.Name), + }, }, } for i, id := range relayIDs { diff --git a/relay/cmd/flags/flags.go b/relay/cmd/flags/flags.go index 63e63369e5..9abd673566 100644 --- a/relay/cmd/flags/flags.go +++ b/relay/cmd/flags/flags.go @@ -85,6 +85,110 @@ var ( EnvVar: common.PrefixEnvVar(envVarPrefix, "CHUNK_MAX_CONCURRENCY"), Value: 32, } + MaxGetBlobOpsPerSecondFlag = cli.Float64Flag{ + Name: common.PrefixFlag(FlagPrefix, "max-get-blob-ops-per-second"), + Usage: "Max number of GetBlob operations per second", + 
Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_GET_BLOB_OPS_PER_SECOND"), + Value: 1024, + } + GetBlobOpsBurstinessFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-blob-ops-burstiness"), + Usage: "Burstiness of the GetBlob rate limiter", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_BLOB_OPS_BURSTINESS"), + Value: 1024, + } + MaxGetBlobBytesPerSecondFlag = cli.Float64Flag{ + Name: common.PrefixFlag(FlagPrefix, "max-get-blob-bytes-per-second"), + Usage: "Max bandwidth for GetBlob operations in bytes per second", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_GET_BLOB_BYTES_PER_SECOND"), + Value: 20 * 1024 * 1024, + } + GetBlobBytesBurstinessFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-blob-bytes-burstiness"), + Usage: "Burstiness of the GetBlob bandwidth rate limiter", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_BLOB_BYTES_BURSTINESS"), + Value: 20 * 1024 * 1024, + } + MaxConcurrentGetBlobOpsFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "max-concurrent-get-blob-ops"), + Usage: "Max number of concurrent GetBlob operations", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_CONCURRENT_GET_BLOB_OPS"), + Value: 1024, + } + MaxGetChunkOpsPerSecondFlag = cli.Float64Flag{ + Name: common.PrefixFlag(FlagPrefix, "max-get-chunk-ops-per-second"), + Usage: "Max number of GetChunk operations per second", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_GET_CHUNK_OPS_PER_SECOND"), + Value: 1024, + } + GetChunkOpsBurstinessFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-chunk-ops-burstiness"), + Usage: "Burstiness of the GetChunk rate limiter", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_CHUNK_OPS_BURSTINESS"), + Value: 1024, + } + MaxGetChunkBytesPerSecondFlag = cli.Float64Flag{ + Name: common.PrefixFlag(FlagPrefix, "max-get-chunk-bytes-per-second"), + Usage: "Max bandwidth for 
GetChunk operations in bytes per second", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_GET_CHUNK_BYTES_PER_SECOND"), + Value: 20 * 1024 * 1024, + } + GetChunkBytesBurstinessFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-chunk-bytes-burstiness"), + Usage: "Burstiness of the GetChunk bandwidth rate limiter", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_CHUNK_BYTES_BURSTINESS"), + Value: 20 * 1024 * 1024, + } + MaxConcurrentGetChunkOpsFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "max-concurrent-get-chunk-ops"), + Usage: "Max number of concurrent GetChunk operations", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_CONCURRENT_GET_CHUNK_OPS"), + Value: 1024, + } + MaxGetChunkOpsPerSecondClientFlag = cli.Float64Flag{ + Name: common.PrefixFlag(FlagPrefix, "max-get-chunk-ops-per-second-client"), + Usage: "Max number of GetChunk operations per second per client", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_GET_CHUNK_OPS_PER_SECOND_CLIENT"), + Value: 8, + } + GetChunkOpsBurstinessClientFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-chunk-ops-burstiness-client"), + Usage: "Burstiness of the GetChunk rate limiter per client", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "GET_CHUNK_OPS_BURSTINESS_CLIENT"), + Value: 8, + } + MaxGetChunkBytesPerSecondClientFlag = cli.Float64Flag{ + Name: common.PrefixFlag(FlagPrefix, "max-get-chunk-bytes-per-second-client"), + Usage: "Max bandwidth for GetChunk operations in bytes per second per client", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_GET_CHUNK_BYTES_PER_SECOND_CLIENT"), + Value: 2 * 1024 * 1024, + } + GetChunkBytesBurstinessClientFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "get-chunk-bytes-burstiness-client"), + Usage: "Burstiness of the GetChunk bandwidth rate limiter per client", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, 
"GET_CHUNK_BYTES_BURSTINESS_CLIENT"), + Value: 2 * 1024 * 1024, + } + MaxConcurrentGetChunkOpsClientFlag = cli.IntFlag{ + Name: common.PrefixFlag(FlagPrefix, "max-concurrent-get-chunk-ops-client"), + Usage: "Max number of concurrent GetChunk operations per client", + Required: false, + EnvVar: common.PrefixEnvVar(envVarPrefix, "MAX_CONCURRENT_GET_CHUNK_OPS_CLIENT"), + Value: 1, + } ) var requiredFlags = []cli.Flag{ @@ -102,6 +206,21 @@ var optionalFlags = []cli.Flag{ BlobMaxConcurrencyFlag, ChunkCacheSizeFlag, ChunkMaxConcurrencyFlag, + MaxGetBlobOpsPerSecondFlag, + GetBlobOpsBurstinessFlag, + MaxGetBlobBytesPerSecondFlag, + GetBlobBytesBurstinessFlag, + MaxConcurrentGetBlobOpsFlag, + MaxGetChunkOpsPerSecondFlag, + GetChunkOpsBurstinessFlag, + MaxGetChunkBytesPerSecondFlag, + GetChunkBytesBurstinessFlag, + MaxConcurrentGetChunkOpsFlag, + MaxGetChunkOpsPerSecondClientFlag, + GetChunkOpsBurstinessClientFlag, + MaxGetChunkBytesPerSecondClientFlag, + GetChunkBytesBurstinessClientFlag, + MaxConcurrentGetChunkOpsClientFlag, } var Flags []cli.Flag diff --git a/relay/limiter/blob_rate_limiter.go b/relay/limiter/blob_rate_limiter.go new file mode 100644 index 0000000000..0ac260cba8 --- /dev/null +++ b/relay/limiter/blob_rate_limiter.go @@ -0,0 +1,102 @@ +package limiter + +import ( + "fmt" + "golang.org/x/time/rate" + "sync" + "time" +) + +// BlobRateLimiter enforces rate limits on GetBlob operations. +type BlobRateLimiter struct { + + // config is the rate limit configuration. + config *Config + + // opLimiter enforces rate limits on the maximum rate of GetBlob operations + opLimiter *rate.Limiter + + // bandwidthLimiter enforces rate limits on the maximum bandwidth consumed by GetBlob operations. Only the size + // of the blob data is considered, not the size of the entire response. + bandwidthLimiter *rate.Limiter + + // operationsInFlight is the number of GetBlob operations currently in flight.
+ operationsInFlight int + + // this lock is used to provide thread safety + lock sync.Mutex +} + +// NewBlobRateLimiter creates a new BlobRateLimiter. +func NewBlobRateLimiter(config *Config) *BlobRateLimiter { + globalGetBlobOpLimiter := rate.NewLimiter( + rate.Limit(config.MaxGetBlobOpsPerSecond), + config.GetBlobOpsBurstiness) + + globalGetBlobBandwidthLimiter := rate.NewLimiter( + rate.Limit(config.MaxGetBlobBytesPerSecond), + config.GetBlobBytesBurstiness) + + return &BlobRateLimiter{ + config: config, + opLimiter: globalGetBlobOpLimiter, + bandwidthLimiter: globalGetBlobBandwidthLimiter, + } +} + +// BeginGetBlobOperation should be called when a GetBlob operation is about to begin. If it returns an error, +// the operation should not be performed. If it does not return an error, FinishGetBlobOperation should be +// called when the operation completes. +func (l *BlobRateLimiter) BeginGetBlobOperation(now time.Time) error { + if l == nil { + // If the rate limiter is nil, do not enforce rate limits. + return nil + } + + l.lock.Lock() + defer l.lock.Unlock() + + if l.operationsInFlight >= l.config.MaxConcurrentGetBlobOps { + return fmt.Errorf("global concurrent request limit exceeded for getBlob operations, try again later") + } + if l.opLimiter.TokensAt(now) < 1 { + return fmt.Errorf("global rate limit exceeded for getBlob operations, try again later") + } + + l.operationsInFlight++ + l.opLimiter.AllowN(now, 1) + + return nil +} + +// FinishGetBlobOperation should be called exactly once for each time BeginGetBlobOperation is called and +// returns nil. +func (l *BlobRateLimiter) FinishGetBlobOperation() { + if l == nil { + // If the rate limiter is nil, do not enforce rate limits. + return + } + + l.lock.Lock() + defer l.lock.Unlock() + + l.operationsInFlight-- +} + +// RequestGetBlobBandwidth should be called when a GetBlob is about to start downloading blob data +// from S3. It returns an error if there is insufficient bandwidth available. 
If it returns nil, the +// operation should proceed. +func (l *BlobRateLimiter) RequestGetBlobBandwidth(now time.Time, bytes uint32) error { + if l == nil { + // If the rate limiter is nil, do not enforce rate limits. + return nil + } + + // no locking needed, the only thing we touch here is the bandwidthLimiter, which is inherently thread-safe + + allowed := l.bandwidthLimiter.AllowN(now, int(bytes)) + if !allowed { + return fmt.Errorf("global rate limit exceeded for getBlob bandwidth, try again later") + } + return nil +} diff --git a/relay/limiter/blob_rate_limiter_test.go b/relay/limiter/blob_rate_limiter_test.go new file mode 100644 index 0000000000..2966b6bea0 --- /dev/null +++ b/relay/limiter/blob_rate_limiter_test.go @@ -0,0 +1,163 @@ +package limiter + +import ( + tu "github.com/Layr-Labs/eigenda/common/testutils" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" + "testing" + "time" +) + +func defaultConfig() *Config { + return &Config{ + MaxGetBlobOpsPerSecond: 1024, + GetBlobOpsBurstiness: 1024, + MaxGetBlobBytesPerSecond: 20 * 1024 * 1024, + GetBlobBytesBurstiness: 20 * 1024 * 1024, + MaxConcurrentGetBlobOps: 1024, + MaxGetChunkOpsPerSecond: 1024, + GetChunkOpsBurstiness: 1024, + MaxGetChunkBytesPerSecond: 20 * 1024 * 1024, + GetChunkBytesBurstiness: 20 * 1024 * 1024, + MaxConcurrentGetChunkOps: 1024, + MaxGetChunkOpsPerSecondClient: 8, + GetChunkOpsBurstinessClient: 8, + MaxGetChunkBytesPerSecondClient: 2 * 1024 * 1024, + GetChunkBytesBurstinessClient: 2 * 1024 * 1024, + MaxConcurrentGetChunkOpsClient: 1, + } +} + +func TestConcurrentBlobOperations(t *testing.T) { + tu.InitializeRandom() + + concurrencyLimit := 1 + rand.Intn(10) + + config := defaultConfig() + config.MaxConcurrentGetBlobOps = concurrencyLimit + // Make the burstiness limit high enough that we won't be rate limited + config.GetBlobOpsBurstiness = concurrencyLimit * 100 + + limiter := NewBlobRateLimiter(config) + + // time starts at current time, but advances manually 
afterward + now := time.Now() + + // We should be able to start this many operations concurrently + for i := 0; i < concurrencyLimit; i++ { + err := limiter.BeginGetBlobOperation(now) + require.NoError(t, err) + } + + // Starting one more operation should fail due to the concurrency limit + err := limiter.BeginGetBlobOperation(now) + require.Error(t, err) + + // Finish an operation. This should permit exactly one more operation to start + limiter.FinishGetBlobOperation() + err = limiter.BeginGetBlobOperation(now) + require.NoError(t, err) + err = limiter.BeginGetBlobOperation(now) + require.Error(t, err) +} + +func TestGetBlobOpRateLimit(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxGetBlobOpsPerSecond = float64(2 + rand.Intn(10)) + config.GetBlobOpsBurstiness = int(config.MaxGetBlobOpsPerSecond) + rand.Intn(10) + config.MaxConcurrentGetBlobOps = 1 + + limiter := NewBlobRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // Without advancing time, we should be able to perform a number of operations equal to the burstiness limit. + for i := 0; i < config.GetBlobOpsBurstiness; i++ { + err := limiter.BeginGetBlobOperation(now) + require.NoError(t, err) + limiter.FinishGetBlobOperation() + } + + // We are now at the rate limit, and should not be able to start another operation. + err := limiter.BeginGetBlobOperation(now) + require.Error(t, err) + + // Advance time by one second. We should gain a number of tokens equal to the rate limit. + now = now.Add(time.Second) + for i := 0; i < int(config.MaxGetBlobOpsPerSecond); i++ { + err = limiter.BeginGetBlobOperation(now) + require.NoError(t, err) + limiter.FinishGetBlobOperation() + } + + // We have once again hit the rate limit. We should not be able to start another operation. + err = limiter.BeginGetBlobOperation(now) + require.Error(t, err) + + // Advance time by another second.
We should gain another number of tokens equal to the rate limit. + // Intentionally do not finish the next operation. We are attempting to get a failure by exceeding + // the max concurrent operations limit. + now = now.Add(time.Second) + err = limiter.BeginGetBlobOperation(now) + require.NoError(t, err) + + // This operation should fail since we have limited concurrent operations to 1. It should not count + // against the rate limit. + err = limiter.BeginGetBlobOperation(now) + require.Error(t, err) + + // "finish" the prior operation. Verify that we have all expected tokens available. + limiter.FinishGetBlobOperation() + for i := 0; i < int(config.MaxGetBlobOpsPerSecond)-1; i++ { + err = limiter.BeginGetBlobOperation(now) + require.NoError(t, err) + limiter.FinishGetBlobOperation() + } + + // We should now be at the rate limit. We should not be able to start another operation. + err = limiter.BeginGetBlobOperation(now) + require.Error(t, err) +} + +func TestGetBlobBandwidthLimit(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxGetBlobBytesPerSecond = float64(1024 + rand.Intn(1024*1024)) + config.GetBlobBytesBurstiness = int(config.MaxGetBlobBytesPerSecond) + rand.Intn(1024*1024) + + limiter := NewBlobRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // Without advancing time, we should be able to utilize a number of bytes equal to the burstiness limit. + bytesRemaining := config.GetBlobBytesBurstiness + for bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err := limiter.RequestGetBlobBandwidth(now, uint32(bytesToRequest)) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // Requesting one more byte should fail due to the bandwidth limit + err := limiter.RequestGetBlobBandwidth(now, 1) + require.Error(t, err) + + // Advance time by one second. We should gain a number of tokens equal to the rate limit. 
+ now = now.Add(time.Second) + bytesRemaining = int(config.MaxGetBlobBytesPerSecond) + for bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err = limiter.RequestGetBlobBandwidth(now, uint32(bytesToRequest)) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // Requesting one more byte should fail due to the bandwidth limit + err = limiter.RequestGetBlobBandwidth(now, 1) + require.Error(t, err) +} diff --git a/relay/limiter/chunk_rate_limiter.go b/relay/limiter/chunk_rate_limiter.go new file mode 100644 index 0000000000..fe899e5b17 --- /dev/null +++ b/relay/limiter/chunk_rate_limiter.go @@ -0,0 +1,151 @@ +package limiter + +import ( + "fmt" + "golang.org/x/time/rate" + "sync" + "time" +) + +// ChunkRateLimiter enforces rate limits on GetChunk operations. +type ChunkRateLimiter struct { + + // config is the rate limit configuration. + config *Config + + // global limiters + + // globalOpLimiter enforces global rate limits on the maximum rate of GetChunk operations + globalOpLimiter *rate.Limiter + + // globalBandwidthLimiter enforces global rate limits on the maximum bandwidth consumed by GetChunk operations. + globalBandwidthLimiter *rate.Limiter + + // globalOperationsInFlight is the number of GetChunk operations currently in flight. + globalOperationsInFlight int + + // per-client limiters + + // Note: in its current form, these expose a DOS vector, since an attacker can create many clients IDs + // and force these maps to become arbitrarily large. This will be remedied when authentication + // is implemented, as only authentication will happen prior to rate limiting. + + // perClientOpLimiter enforces per-client rate limits on the maximum rate of GetChunk operations + perClientOpLimiter map[string]*rate.Limiter + + // perClientBandwidthLimiter enforces per-client rate limits on the maximum bandwidth consumed by + // GetChunk operations. 
+ perClientBandwidthLimiter map[string]*rate.Limiter + + // perClientOperationsInFlight is the number of GetChunk operations currently in flight for each client. + perClientOperationsInFlight map[string]int + + // this lock is used to provide thread safety + lock sync.Mutex +} + +// NewChunkRateLimiter creates a new ChunkRateLimiter. +func NewChunkRateLimiter(config *Config) *ChunkRateLimiter { + + globalOpLimiter := rate.NewLimiter(rate.Limit( + config.MaxGetChunkOpsPerSecond), + config.GetChunkOpsBurstiness) + + globalBandwidthLimiter := rate.NewLimiter(rate.Limit( + config.MaxGetChunkBytesPerSecond), + config.GetChunkBytesBurstiness) + + return &ChunkRateLimiter{ + config: config, + globalOpLimiter: globalOpLimiter, + globalBandwidthLimiter: globalBandwidthLimiter, + perClientOpLimiter: make(map[string]*rate.Limiter), + perClientBandwidthLimiter: make(map[string]*rate.Limiter), + perClientOperationsInFlight: make(map[string]int), + } +} + +// BeginGetChunkOperation should be called when a GetChunk operation is about to begin. If it returns an error, +// the operation should not be performed. If it does not return an error, FinishGetChunkOperation should be +// called when the operation completes. +func (l *ChunkRateLimiter) BeginGetChunkOperation( + now time.Time, + requesterID string) error { + if l == nil { + // If the rate limiter is nil, do not enforce rate limits. + return nil + } + + l.lock.Lock() + defer l.lock.Unlock() + + _, ok := l.perClientOperationsInFlight[requesterID] + if !ok { + // This is the first time we've seen this client ID. 
+ l.perClientOperationsInFlight[requesterID] = 0 + + l.perClientOpLimiter[requesterID] = rate.NewLimiter( + rate.Limit(l.config.MaxGetChunkOpsPerSecondClient), + l.config.GetChunkOpsBurstinessClient) + + l.perClientBandwidthLimiter[requesterID] = rate.NewLimiter( + rate.Limit(l.config.MaxGetChunkBytesPerSecondClient), + l.config.GetChunkBytesBurstinessClient) + } + + if l.globalOperationsInFlight >= l.config.MaxConcurrentGetChunkOps { + return fmt.Errorf("global concurrent request limit exceeded for GetChunks operations, try again later") + } + if l.globalOpLimiter.TokensAt(now) < 1 { + return fmt.Errorf("global rate limit exceeded for GetChunks operations, try again later") + } + if l.perClientOperationsInFlight[requesterID] >= l.config.MaxConcurrentGetChunkOpsClient { + return fmt.Errorf("client concurrent request limit exceeded for GetChunks") + } + if l.perClientOpLimiter[requesterID].TokensAt(now) < 1 { + return fmt.Errorf("client rate limit exceeded for GetChunks, try again later") + } + + l.globalOperationsInFlight++ + l.perClientOperationsInFlight[requesterID]++ + l.globalOpLimiter.AllowN(now, 1) + l.perClientOpLimiter[requesterID].AllowN(now, 1) + + return nil +} + +// FinishGetChunkOperation should be called when a GetChunk operation completes. +func (l *ChunkRateLimiter) FinishGetChunkOperation(requesterID string) { + if l == nil { + return + } + + l.lock.Lock() + defer l.lock.Unlock() + + l.globalOperationsInFlight-- + l.perClientOperationsInFlight[requesterID]-- +} + +// RequestGetChunkBandwidth should be called when a GetChunk is about to start downloading chunk data. +func (l *ChunkRateLimiter) RequestGetChunkBandwidth(now time.Time, requesterID string, bytes int) error { + if l == nil { + // If the rate limiter is nil, do not enforce rate limits. 
+ return nil + } + + // The per-client limiter map is mutated by BeginGetChunkOperation under l.lock, so the map lookup must + // also happen under the lock; the rate.Limiter values themselves are thread-safe once fetched. + l.lock.Lock() + clientLimiter, ok := l.perClientBandwidthLimiter[requesterID] + l.lock.Unlock() + if !ok { + return fmt.Errorf("client must call BeginGetChunkOperation before requesting bandwidth") + } + + allowed := l.globalBandwidthLimiter.AllowN(now, bytes) + if !allowed { + return fmt.Errorf("global rate limit exceeded for GetChunk bandwidth, try again later") + } + + allowed = clientLimiter.AllowN(now, bytes) + if !allowed { + l.globalBandwidthLimiter.AllowN(now, -bytes) + return fmt.Errorf("client rate limit exceeded for GetChunk bandwidth, try again later") + } + + return nil +} diff --git a/relay/limiter/chunk_rate_limiter_test.go b/relay/limiter/chunk_rate_limiter_test.go new file mode 100644 index 0000000000..59399ca17f --- /dev/null +++ b/relay/limiter/chunk_rate_limiter_test.go @@ -0,0 +1,335 @@ +package limiter + +import ( + tu "github.com/Layr-Labs/eigenda/common/testutils" + "github.com/stretchr/testify/require" + "golang.org/x/exp/rand" + "math" + "testing" + "time" +) + +func TestConcurrentGetChunksOperations(t *testing.T) { + tu.InitializeRandom() + + concurrencyLimit := 1 + rand.Intn(10) + + config := defaultConfig() + config.MaxConcurrentGetChunkOps = concurrencyLimit + config.MaxConcurrentGetChunkOpsClient = math.MaxInt32 + config.GetChunkOpsBurstiness = math.MaxInt32 + config.GetChunkOpsBurstinessClient = math.MaxInt32 + + userID := tu.RandomString(64) + + limiter := NewChunkRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // We should be able to start this many operations concurrently + for i := 0; i < concurrencyLimit; i++ { + err := limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + } + + // Starting one more operation should fail due to the concurrency limit + err := limiter.BeginGetChunkOperation(now, userID) + require.Error(t, err) + + // Finish an operation.
This should permit exactly one more operation to start + limiter.FinishGetChunkOperation(userID) + err = limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + err = limiter.BeginGetChunkOperation(now, userID) + require.Error(t, err) +} + +func TestGetChunksRateLimit(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxGetChunkOpsPerSecond = float64(2 + rand.Intn(10)) + config.GetChunkOpsBurstiness = int(config.MaxGetChunkOpsPerSecond) + rand.Intn(10) + config.GetChunkOpsBurstinessClient = math.MaxInt32 + config.MaxConcurrentGetChunkOps = 1 + + userID := tu.RandomString(64) + + limiter := NewChunkRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // Without advancing time, we should be able to perform a number of operations equal to the burstiness limit. + for i := 0; i < config.GetChunkOpsBurstiness; i++ { + err := limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID) + } + + // We are now at the rate limit, and should not be able to start another operation. + err := limiter.BeginGetChunkOperation(now, userID) + require.Error(t, err) + + // Advance time by one second. We should now be able to perform a number of operations equal to the rate limit. + now = now.Add(time.Second) + for i := 0; i < int(config.MaxGetChunkOpsPerSecond); i++ { + err = limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID) + } + + // We are now at the rate limit, and should not be able to start another operation. + err = limiter.BeginGetChunkOperation(now, userID) + require.Error(t, err) + + // Advance time by one second. + // Intentionally do not finish the operation. We are attempting to see what happens when an operation fails + // due to the limit on parallel operations. 
+ now = now.Add(time.Second) + err = limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + + // This operation will fail due to the concurrency limit. It should not affect the rate limit. + err = limiter.BeginGetChunkOperation(now, userID) + require.Error(t, err) + + // Finish the operation that was started in the previous second. This should permit the next operation to start. + limiter.FinishGetChunkOperation(userID) + + // Verify that we have the expected number of available tokens. + for i := 0; i < int(config.MaxGetChunkOpsPerSecond)-1; i++ { + err = limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID) + } + + // We are now at the rate limit, and should not be able to start another operation. + err = limiter.BeginGetChunkOperation(now, userID) + require.Error(t, err) +} + +func TestGetChunksBandwidthLimit(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxGetChunkBytesPerSecond = float64(1024 + rand.Intn(1024*1024)) + config.GetChunkBytesBurstiness = int(config.MaxGetChunkBytesPerSecond) + rand.Intn(1024*1024) + config.GetChunkBytesBurstinessClient = math.MaxInt32 + + userID := tu.RandomString(64) + + limiter := NewChunkRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // "register" the user ID + err := limiter.BeginGetChunkOperation(now, userID) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID) + + // Without advancing time, we should be able to utilize a number of bytes equal to the burstiness limit.
+ bytesRemaining := config.GetChunkBytesBurstiness + for bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err = limiter.RequestGetChunkBandwidth(now, userID, bytesToRequest) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // Requesting one more byte should fail due to the bandwidth limit + err = limiter.RequestGetChunkBandwidth(now, userID, 1) + require.Error(t, err) + + // Advance time by one second. We should gain a number of tokens equal to the rate limit. + now = now.Add(time.Second) + bytesRemaining = int(config.MaxGetChunkBytesPerSecond) + for bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err = limiter.RequestGetChunkBandwidth(now, userID, bytesToRequest) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // Requesting one more byte should fail due to the bandwidth limit + err = limiter.RequestGetChunkBandwidth(now, userID, 1) + require.Error(t, err) +} + +func TestPerClientConcurrencyLimit(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxConcurrentGetChunkOpsClient = 1 + rand.Intn(10) + config.MaxConcurrentGetChunkOps = 2 * config.MaxConcurrentGetChunkOpsClient + config.GetChunkOpsBurstinessClient = math.MaxInt32 + config.GetChunkOpsBurstiness = math.MaxInt32 + + userID1 := tu.RandomString(64) + userID2 := tu.RandomString(64) + + limiter := NewChunkRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // Start the maximum permitted number of operations for user 1 + for i := 0; i < config.MaxConcurrentGetChunkOpsClient; i++ { + err := limiter.BeginGetChunkOperation(now, userID1) + require.NoError(t, err) + } + + // Starting another operation for user 1 should fail due to the concurrency limit + err := limiter.BeginGetChunkOperation(now, userID1) + require.Error(t, err) + + // The failure to start the operation for client 1 should not use up any of the global concurrency slots. 
+ // To verify this, allow the maximum number of operations for client 2 to start. + for i := 0; i < config.MaxConcurrentGetChunkOpsClient; i++ { + err := limiter.BeginGetChunkOperation(now, userID2) + require.NoError(t, err) + } + + // Starting another operation for client 2 should fail due to the concurrency limit + err = limiter.BeginGetChunkOperation(now, userID2) + require.Error(t, err) + + // Ending an operation from client 2 should not affect the concurrency limit for client 1. + limiter.FinishGetChunkOperation(userID2) + err = limiter.BeginGetChunkOperation(now, userID1) + require.Error(t, err) + + // Ending an operation from client 1 should permit another operation for client 1 to start. + limiter.FinishGetChunkOperation(userID1) + err = limiter.BeginGetChunkOperation(now, userID1) + require.NoError(t, err) +} + +func TestOpLimitPerClient(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxGetChunkOpsPerSecondClient = float64(2 + rand.Intn(10)) + config.GetChunkOpsBurstinessClient = int(config.MaxGetChunkOpsPerSecondClient) + rand.Intn(10) + config.GetChunkOpsBurstiness = math.MaxInt32 + + userID1 := tu.RandomString(64) + userID2 := tu.RandomString(64) + + limiter := NewChunkRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // Without advancing time, we should be able to perform a number of operations equal to the burstiness limit. + for i := 0; i < config.GetChunkOpsBurstinessClient; i++ { + err := limiter.BeginGetChunkOperation(now, userID1) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID1) + } + + // We are now at the client rate limit, and should not be able to start another operation. + err := limiter.BeginGetChunkOperation(now, userID1) + require.Error(t, err) + + // Client 2 should not be rate limited based on actions by client 1.
+ for i := 0; i < config.GetChunkOpsBurstinessClient; i++ { + err := limiter.BeginGetChunkOperation(now, userID2) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID2) + } + + // Client 2 should now have exhausted its burstiness limit. + err = limiter.BeginGetChunkOperation(now, userID2) + require.Error(t, err) + + // Advancing time by a second should permit more operations. + now = now.Add(time.Second) + for i := 0; i < int(config.MaxGetChunkOpsPerSecondClient); i++ { + err = limiter.BeginGetChunkOperation(now, userID1) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID1) + err = limiter.BeginGetChunkOperation(now, userID2) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID2) + } + + // No more operations should be permitted for either client. + err = limiter.BeginGetChunkOperation(now, userID1) + require.Error(t, err) + err = limiter.BeginGetChunkOperation(now, userID2) + require.Error(t, err) +} + +func TestBandwidthLimitPerClient(t *testing.T) { + tu.InitializeRandom() + + config := defaultConfig() + config.MaxGetChunkBytesPerSecondClient = float64(1024 + rand.Intn(1024*1024)) + config.GetChunkBytesBurstinessClient = int(config.MaxGetChunkBytesPerSecondClient) + rand.Intn(1024*1024) + config.GetChunkBytesBurstiness = math.MaxInt32 + config.GetChunkOpsBurstiness = math.MaxInt32 + config.GetChunkOpsBurstinessClient = math.MaxInt32 + + userID1 := tu.RandomString(64) + userID2 := tu.RandomString(64) + + limiter := NewChunkRateLimiter(config) + + // time starts at current time, but advances manually afterward + now := time.Now() + + // "register" the user IDs + err := limiter.BeginGetChunkOperation(now, userID1) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID1) + err = limiter.BeginGetChunkOperation(now, userID2) + require.NoError(t, err) + limiter.FinishGetChunkOperation(userID2) + + // Request maximum possible bandwidth for client 1 + bytesRemaining := config.GetChunkBytesBurstinessClient + for
bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err = limiter.RequestGetChunkBandwidth(now, userID1, bytesToRequest) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // Requesting one more byte should fail due to the bandwidth limit + err = limiter.RequestGetChunkBandwidth(now, userID1, 1) + require.Error(t, err) + + // User 2 should have its full bandwidth allowance available + bytesRemaining = config.GetChunkBytesBurstinessClient + for bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err = limiter.RequestGetChunkBandwidth(now, userID2, bytesToRequest) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // Requesting one more byte should fail due to the bandwidth limit + err = limiter.RequestGetChunkBandwidth(now, userID2, 1) + require.Error(t, err) + + // Advance time by one second. We should gain a number of tokens equal to the rate limit. + now = now.Add(time.Second) + bytesRemaining = int(config.MaxGetChunkBytesPerSecondClient) + for bytesRemaining > 0 { + bytesToRequest := 1 + rand.Intn(bytesRemaining) + err = limiter.RequestGetChunkBandwidth(now, userID1, bytesToRequest) + require.NoError(t, err) + err = limiter.RequestGetChunkBandwidth(now, userID2, bytesToRequest) + require.NoError(t, err) + bytesRemaining -= bytesToRequest + } + + // All bandwidth should now be exhausted for both clients + err = limiter.RequestGetChunkBandwidth(now, userID1, 1) + require.Error(t, err) + err = limiter.RequestGetChunkBandwidth(now, userID2, 1) + require.Error(t, err) +} diff --git a/relay/limiter/config.go b/relay/limiter/config.go new file mode 100644 index 0000000000..5f19d9362a --- /dev/null +++ b/relay/limiter/config.go @@ -0,0 +1,65 @@ +package limiter + +// Config is the configuration for the relay rate limiting. +type Config struct { + + // Blob rate limiting + + // MaxGetBlobOpsPerSecond is the maximum permitted number of GetBlob operations per second. Default is + // 1024. 
+	MaxGetBlobOpsPerSecond float64
+	// The burstiness of the MaxGetBlobOpsPerSecond rate limiter. This is the maximum burst size that can happen
+	// within a short time window. Default is 1024.
+	GetBlobOpsBurstiness int
+
+	// MaxGetBlobBytesPerSecond is the maximum bandwidth, in bytes, that GetBlob operations are permitted
+	// to consume per second. Default is 20MiB/s.
+	MaxGetBlobBytesPerSecond float64
+	// The burstiness of the MaxGetBlobBytesPerSecond rate limiter. This is the maximum burst size that can happen
+	// within a short time window. Default is 20MiB.
+	GetBlobBytesBurstiness int
+
+	// MaxConcurrentGetBlobOps is the maximum number of concurrent GetBlob operations that are permitted.
+	// This is in addition to the rate limits. Default is 1024.
+	MaxConcurrentGetBlobOps int
+
+	// Chunk rate limiting
+
+	// MaxGetChunkOpsPerSecond is the maximum permitted number of GetChunk operations per second. Default is
+	// 1024.
+	MaxGetChunkOpsPerSecond float64
+	// The burstiness of the MaxGetChunkOpsPerSecond rate limiter. This is the maximum burst size that can happen
+	// within a short time window. Default is 1024.
+	GetChunkOpsBurstiness int
+
+	// MaxGetChunkBytesPerSecond is the maximum bandwidth, in bytes, that GetChunk operations are permitted
+	// to consume per second. Default is 20MiB/s.
+	MaxGetChunkBytesPerSecond float64
+	// The burstiness of the MaxGetChunkBytesPerSecond rate limiter. This is the maximum burst size that can happen
+	// within a short time window. Default is 20MiB.
+	GetChunkBytesBurstiness int
+
+	// MaxConcurrentGetChunkOps is the maximum number of concurrent GetChunk operations that are permitted.
+	// Default is 1024.
+	MaxConcurrentGetChunkOps int
+
+	// Client rate limiting for GetChunk operations
+
+	// MaxGetChunkOpsPerSecondClient is the maximum permitted number of GetChunk operations per second for a single
+	// client. Default is 8.
+	MaxGetChunkOpsPerSecondClient float64
+	// The burstiness of the MaxGetChunkOpsPerSecondClient rate limiter. This is the maximum burst size that can
+	// happen within a short time window. Default is 8.
+	GetChunkOpsBurstinessClient int
+
+	// MaxGetChunkBytesPerSecondClient is the maximum bandwidth, in bytes, that GetChunk operations are permitted
+	// to consume per second. Default is 2MiB/s.
+	MaxGetChunkBytesPerSecondClient float64
+	// The burstiness of the MaxGetChunkBytesPerSecondClient rate limiter. This is the maximum burst size that can
+	// happen within a short time window. Default is 2MiB.
+	GetChunkBytesBurstinessClient int
+
+	// MaxConcurrentGetChunkOpsClient is the maximum number of concurrent GetChunk operations that are permitted.
+	// Default is 1.
+	MaxConcurrentGetChunkOpsClient int
+}
diff --git a/relay/limiter/limiter_test.go b/relay/limiter/limiter_test.go
new file mode 100644
index 0000000000..6064220f40
--- /dev/null
+++ b/relay/limiter/limiter_test.go
@@ -0,0 +1,86 @@
+package limiter
+
+import (
+	"github.com/stretchr/testify/require"
+	"golang.org/x/time/rate"
+	"testing"
+	"time"
+)
+
+// The rate.Limiter library has less documentation than ideal. Although I can figure out what it's doing by reading
+// the code, I think it's risky writing things that depend on what may change in the future. In these tests, I verify
+// some basic properties of the rate.Limiter library, so that if these properties ever change in the future, the tests
+// will fail and we'll know to update the code.
+ +func TestPositiveTokens(t *testing.T) { + configuredRate := rate.Limit(10.0) + // "burst" is equivalent to the bucket size, aka the number of tokens that can be stored + configuredBurst := 10 + + // time starts at current time, but advances manually afterward + now := time.Now() + + rateLimiter := rate.NewLimiter(configuredRate, configuredBurst) + + // number of tokens should equal the burst limit + require.Equal(t, configuredBurst, int(rateLimiter.TokensAt(now))) + + // moving forward in time should not change the number of tokens + now = now.Add(time.Second) + require.Equal(t, configuredBurst, int(rateLimiter.TokensAt(now))) + + // remove each token without advancing time + for i := 0; i < configuredBurst; i++ { + require.True(t, rateLimiter.AllowN(now, 1)) + require.Equal(t, configuredBurst-i-1, int(rateLimiter.TokensAt(now))) + } + require.Equal(t, 0, int(rateLimiter.TokensAt(now))) + + // removing an additional token should fail + require.False(t, rateLimiter.AllowN(now, 1)) + require.Equal(t, 0, int(rateLimiter.TokensAt(now))) + + // tokens should return at a rate of once per 100ms + for i := 0; i < configuredBurst; i++ { + now = now.Add(100 * time.Millisecond) + require.Equal(t, i+1, int(rateLimiter.TokensAt(now))) + } + require.Equal(t, configuredBurst, int(rateLimiter.TokensAt(now))) + + // remove 7 tokens all at once + require.True(t, rateLimiter.AllowN(now, 7)) + require.Equal(t, 3, int(rateLimiter.TokensAt(now))) + + // move forward 500ms, returning 5 tokens + now = now.Add(500 * time.Millisecond) + require.Equal(t, 8, int(rateLimiter.TokensAt(now))) + + // try to take more than the burst limit + require.False(t, rateLimiter.AllowN(now, 100)) +} + +func TestNegativeTokens(t *testing.T) { + configuredRate := rate.Limit(10.0) + // "burst" is equivalent to the bucket size, aka the number of tokens that can be stored + configuredBurst := 10 + + // time starts at current time, but advances manually afterward + now := time.Now() + + rateLimiter := 
rate.NewLimiter(configuredRate, configuredBurst)
+
+	// number of tokens should equal the burst limit
+	require.Equal(t, configuredBurst, int(rateLimiter.TokensAt(now)))
+
+	// remove all tokens then add them back
+	require.True(t, rateLimiter.AllowN(now, configuredBurst))
+	require.Equal(t, 0, int(rateLimiter.TokensAt(now)))
+	for i := 0; i < configuredBurst; i++ {
+		require.True(t, rateLimiter.AllowN(now, -1))
+		require.Equal(t, i+1, int(rateLimiter.TokensAt(now)))
+	}
+
+	// nothing funky should happen when time advances
+	now = now.Add(100 * time.Second)
+	require.Equal(t, configuredBurst, int(rateLimiter.TokensAt(now)))
+}
diff --git a/relay/metadata_provider.go b/relay/metadata_provider.go
index f5b583d59f..3e32924072 100644
--- a/relay/metadata_provider.go
+++ b/relay/metadata_provider.go
@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"github.com/Layr-Labs/eigenda/core/v2"
 	"github.com/Layr-Labs/eigenda/disperser/common/v2/blobstore"
+	"github.com/Layr-Labs/eigenda/encoding"
 	"github.com/Layr-Labs/eigenda/relay/cache"
 	"github.com/Layr-Labs/eigensdk-go/logging"
 	"sync/atomic"
@@ -15,6 +16,8 @@ import (
 type blobMetadata struct {
 	// the size of the blob in bytes
 	blobSizeBytes uint32
+	// the size of each encoded chunk
+	chunkSizeBytes uint32
 	// the size of the file containing the encoded chunks
 	totalChunkSizeBytes uint32
 	// the fragment size used for uploading the encoded chunks
@@ -153,8 +156,17 @@ func (m *metadataProvider) fetchMetadata(key v2.BlobKey) (*blobMetadata, error)
 		}
 	}
 
+	// TODO(cody-littley): blob size is not correct https://github.com/Layr-Labs/eigenda/pull/906#discussion_r1847396530
+	blobSize := uint32(cert.BlobHeader.BlobCommitments.Length)
+	chunkSize, err := v2.GetChunkLength(cert.BlobHeader.BlobVersion, blobSize)
+	if err != nil {
+		return nil, fmt.Errorf("error getting chunk length: %w", err)
+	}
+	chunkSize *= encoding.BYTES_PER_SYMBOL
+
+	metadata := &blobMetadata{
-		blobSizeBytes: 0, /* Future work: populate this once it is added to the metadata store */
+
blobSizeBytes: blobSize,
+		chunkSizeBytes: chunkSize,
 		totalChunkSizeBytes: fragmentInfo.TotalChunkSizeBytes,
 		fragmentSizeBytes: fragmentInfo.FragmentSizeBytes,
 	}
diff --git a/relay/relay_test_utils.go b/relay/relay_test_utils.go
index 0f5fdf1cf9..f850b65cc7 100644
--- a/relay/relay_test_utils.go
+++ b/relay/relay_test_utils.go
@@ -178,7 +178,7 @@ func buildChunkStore(t *testing.T, logger logging.Logger) (chunkstore.ChunkReade
 
 func randomBlob(t *testing.T) (*v2.BlobHeader, []byte) {
 
-	data := tu.RandomBytes(128)
+	data := tu.RandomBytes(225) // TODO talk to Ian about this
 
 	data = codec.ConvertByPaddingEmptyByte(data)
 	commitments, err := prover.GetCommitments(data)
diff --git a/relay/server.go b/relay/server.go
index c599c3e335..ad6072b9fe 100644
--- a/relay/server.go
+++ b/relay/server.go
@@ -4,8 +4,6 @@ import (
 	"context"
 	"errors"
 	"fmt"
-	"net"
-
 	pb "github.com/Layr-Labs/eigenda/api/grpc/relay"
 	"github.com/Layr-Labs/eigenda/common/healthcheck"
 	"github.com/Layr-Labs/eigenda/core"
@@ -13,9 +11,12 @@ import (
 	"github.com/Layr-Labs/eigenda/disperser/common/v2/blobstore"
 	"github.com/Layr-Labs/eigenda/encoding"
 	"github.com/Layr-Labs/eigenda/relay/chunkstore"
+	"github.com/Layr-Labs/eigenda/relay/limiter"
 	"github.com/Layr-Labs/eigensdk-go/logging"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/reflection"
+	"net"
+	"time"
 )
 
 var _ pb.RelayServer = &Server{}
@@ -24,11 +25,12 @@ var _ pb.RelayServer = &Server{}
 type Server struct {
 	pb.UnimplementedRelayServer
 
+	// config is the configuration for the relay Server.
+	config *Config
+
 	// the logger for the server
 	logger logging.Logger
 
-	config *Config
-
 	// metadataProvider encapsulates logic for fetching metadata for blobs.
 	metadataProvider *metadataProvider
 
@@ -38,32 +40,53 @@ type Server struct {
 	// chunkProvider encapsulates logic for fetching chunks.
 	chunkProvider *chunkProvider
 
+	// blobRateLimiter enforces rate limits on GetBlob operations.
+ blobRateLimiter *limiter.BlobRateLimiter + + // chunkRateLimiter enforces rate limits on GetChunk operations. + chunkRateLimiter *limiter.ChunkRateLimiter + // grpcServer is the gRPC server. grpcServer *grpc.Server } type Config struct { + + // RelayIDs contains the IDs of the relays that this server is willing to serve data for. If empty, the server will + // serve data for any shard it can. + RelayIDs []v2.RelayKey + // GRPCPort is the port that the relay server listens on. GRPCPort int + // MaxGRPCMessageSize is the maximum size of a gRPC message that the server will accept. MaxGRPCMessageSize int + // MetadataCacheSize is the maximum number of items in the metadata cache. MetadataCacheSize int + // MetadataMaxConcurrency puts a limit on the maximum number of concurrent metadata fetches actively running on // goroutines. MetadataMaxConcurrency int + // BlobCacheSize is the maximum number of items in the blob cache. BlobCacheSize int + // BlobMaxConcurrency puts a limit on the maximum number of concurrent blob fetches actively running on goroutines. BlobMaxConcurrency int + // ChunkCacheSize is the maximum number of items in the chunk cache. ChunkCacheSize int + // ChunkMaxConcurrency is the size of the work pool for fetching chunks. Note that this does not // impact concurrency utilized by the s3 client to upload/download fragmented files. ChunkMaxConcurrency int - // RelayIDs contains the IDs of the relays that this server is willing to serve data for. If empty, the server will - // serve data for any shard it can. - RelayIDs []v2.RelayKey + + // MaxKeysPerGetChunksRequest is the maximum number of keys that can be requested in a single GetChunks request. + MaxKeysPerGetChunksRequest int + + // RateLimits contains configuration for rate limiting. + RateLimits limiter.Config } // NewServer creates a new relay Server. 
@@ -107,22 +130,28 @@ func NewServer( } return &Server{ - logger: logger, config: config, + logger: logger, metadataProvider: mp, blobProvider: bp, chunkProvider: cp, + blobRateLimiter: limiter.NewBlobRateLimiter(&config.RateLimits), + chunkRateLimiter: limiter.NewChunkRateLimiter(&config.RateLimits), }, nil } // GetBlob retrieves a blob stored by the relay. func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.GetBlobReply, error) { - // Future work : - // - global throttle - // - per-connection throttle + // TODO(cody-littley): // - timeouts + err := s.blobRateLimiter.BeginGetBlobOperation(time.Now()) + if err != nil { + return nil, err + } + defer s.blobRateLimiter.FinishGetBlobOperation() + key, err := v2.BytesToBlobKey(request.BlobKey) if err != nil { return nil, fmt.Errorf("invalid blob key: %w", err) @@ -139,6 +168,11 @@ func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.G return nil, fmt.Errorf("blob not found") } + err = s.blobRateLimiter.RequestGetBlobBandwidth(time.Now(), metadata.blobSizeBytes) + if err != nil { + return nil, err + } + data, err := s.blobProvider.GetBlob(key) if err != nil { return nil, fmt.Errorf("error fetching blob %s: %w", key.Hex(), err) @@ -154,16 +188,63 @@ func (s *Server) GetBlob(ctx context.Context, request *pb.GetBlobRequest) (*pb.G // GetChunks retrieves chunks from blobs stored by the relay. 
func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (*pb.GetChunksReply, error) { - // Future work: + // TODO(cody-littley): // - authentication - // - global throttle - // - per-connection throttle // - timeouts if len(request.ChunkRequests) <= 0 { return nil, fmt.Errorf("no chunk requests provided") } + if len(request.ChunkRequests) > s.config.MaxKeysPerGetChunksRequest { + return nil, fmt.Errorf( + "too many chunk requests provided, max is %d", s.config.MaxKeysPerGetChunksRequest) + } + + // Future work: client IDs will be fixed when authentication is implemented + clientID := fmt.Sprintf("%d", request.RequesterId) + err := s.chunkRateLimiter.BeginGetChunkOperation(time.Now(), clientID) + if err != nil { + return nil, err + } + defer s.chunkRateLimiter.FinishGetChunkOperation(clientID) + + keys, err := getKeysFromChunkRequest(request) + if err != nil { + return nil, err + } + + mMap, err := s.metadataProvider.GetMetadataForBlobs(keys) + if err != nil { + return nil, fmt.Errorf( + "error fetching metadata for blob, check if blob exists and is assigned to this relay: %w", err) + } + + requiredBandwidth, err := computeChunkRequestRequiredBandwidth(request, mMap) + if err != nil { + return nil, fmt.Errorf("error computing required bandwidth: %w", err) + } + err = s.chunkRateLimiter.RequestGetChunkBandwidth(time.Now(), clientID, requiredBandwidth) + if err != nil { + return nil, err + } + + frames, err := s.chunkProvider.GetFrames(ctx, mMap) + if err != nil { + return nil, fmt.Errorf("error fetching frames: %w", err) + } + bytesToSend, err := gatherChunkDataToSend(frames, request) + if err != nil { + return nil, fmt.Errorf("error gathering chunk data: %w", err) + } + + return &pb.GetChunksReply{ + Data: bytesToSend, + }, nil +} + +// getKeysFromChunkRequest gathers a slice of blob keys from a GetChunks request. 
+func getKeysFromChunkRequest(request *pb.GetChunksRequest) ([]v2.BlobKey, error) { keys := make([]v2.BlobKey, 0, len(request.ChunkRequests)) for _, chunkRequest := range request.ChunkRequests { @@ -184,20 +265,16 @@ func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (* keys = append(keys, key) } - mMap, err := s.metadataProvider.GetMetadataForBlobs(keys) - if err != nil { - return nil, fmt.Errorf( - "error fetching metadata for blob, check if blob exists and is assigned to this relay: %w", err) - } + return keys, nil +} - frames, err := s.chunkProvider.GetFrames(ctx, mMap) - if err != nil { - return nil, fmt.Errorf("error fetching frames: %w", err) - } +// gatherChunkDataToSend takes the chunk data and narrows it down to the data requested in the GetChunks request. +func gatherChunkDataToSend( + frames map[v2.BlobKey][]*encoding.Frame, + request *pb.GetChunksRequest) ([][]byte, error) { - bytesToSend := make([][]byte, 0, len(keys)) + bytesToSend := make([][]byte, 0, len(frames)) - // return data in the order that it was requested for _, chunkRequest := range request.ChunkRequests { framesToSend := make([]*encoding.Frame, 0) @@ -246,14 +323,40 @@ func (s *Server) GetChunks(ctx context.Context, request *pb.GetChunksRequest) (* bytesToSend = append(bytesToSend, bundleBytes) } - return &pb.GetChunksReply{ - Data: bytesToSend, - }, nil + return bytesToSend, nil +} + +// computeChunkRequestRequiredBandwidth computes the bandwidth required to fulfill a GetChunks request. 
+func computeChunkRequestRequiredBandwidth(request *pb.GetChunksRequest, mMap metadataMap) (int, error) { + requiredBandwidth := 0 + for _, req := range request.ChunkRequests { + var metadata *blobMetadata + var key v2.BlobKey + var requestedChunks int + + if req.GetByIndex() != nil { + key = v2.BlobKey(req.GetByIndex().GetBlobKey()) + metadata = mMap[key] + requestedChunks = len(req.GetByIndex().ChunkIndices) + } else { + key = v2.BlobKey(req.GetByRange().GetBlobKey()) + metadata = mMap[key] + requestedChunks = int(req.GetByRange().EndIndex - req.GetByRange().StartIndex) + } + + if metadata == nil { + return 0, fmt.Errorf("metadata not found for key %s", key.Hex()) + } + + requiredBandwidth += requestedChunks * int(metadata.chunkSizeBytes) + } + + return requiredBandwidth, nil + } // Start starts the server listening for requests. This method will block until the server is stopped. func (s *Server) Start() error { - // Serve grpc requests addr := fmt.Sprintf("0.0.0.0:%d", s.config.GRPCPort) listener, err := net.Listen("tcp", addr) diff --git a/relay/server_test.go b/relay/server_test.go index d480349067..cedfa6ddb4 100644 --- a/relay/server_test.go +++ b/relay/server_test.go @@ -2,6 +2,7 @@ package relay import ( "context" + "github.com/Layr-Labs/eigenda/relay/limiter" "math/rand" "testing" @@ -18,14 +19,32 @@ import ( func defaultConfig() *Config { return &Config{ - GRPCPort: 50051, - MaxGRPCMessageSize: 1024 * 1024 * 300, - MetadataCacheSize: 1024 * 1024, - MetadataMaxConcurrency: 32, - BlobCacheSize: 32, - BlobMaxConcurrency: 32, - ChunkCacheSize: 32, - ChunkMaxConcurrency: 32, + GRPCPort: 50051, + MaxGRPCMessageSize: 1024 * 1024 * 300, + MetadataCacheSize: 1024 * 1024, + MetadataMaxConcurrency: 32, + BlobCacheSize: 32, + BlobMaxConcurrency: 32, + ChunkCacheSize: 32, + ChunkMaxConcurrency: 32, + MaxKeysPerGetChunksRequest: 1024, + RateLimits: limiter.Config{ + MaxGetBlobOpsPerSecond: 1024, + GetBlobOpsBurstiness: 1024, + MaxGetBlobBytesPerSecond: 20 * 1024 * 
1024, + GetBlobBytesBurstiness: 20 * 1024 * 1024, + MaxConcurrentGetBlobOps: 1024, + MaxGetChunkOpsPerSecond: 1024, + GetChunkOpsBurstiness: 1024, + MaxGetChunkBytesPerSecond: 20 * 1024 * 1024, + GetChunkBytesBurstiness: 20 * 1024 * 1024, + MaxConcurrentGetChunkOps: 1024, + MaxGetChunkOpsPerSecondClient: 8, + GetChunkOpsBurstinessClient: 8, + MaxGetChunkBytesPerSecondClient: 2 * 1024 * 1024, + GetChunkBytesBurstinessClient: 2 * 1024 * 1024, + MaxConcurrentGetChunkOpsClient: 1, + }, } } @@ -318,6 +337,10 @@ func TestReadWriteChunks(t *testing.T) { // This is the server used to read it back config := defaultConfig() + config.RateLimits.MaxGetChunkOpsPerSecond = 1000 + config.RateLimits.GetChunkOpsBurstiness = 1000 + config.RateLimits.MaxGetChunkOpsPerSecondClient = 1000 + config.RateLimits.GetChunkOpsBurstinessClient = 1000 server, err := NewServer( context.Background(), logger, @@ -634,6 +657,10 @@ func TestReadWriteChunksWithSharding(t *testing.T) { // This is the server used to read it back config := defaultConfig() config.RelayIDs = shardList + config.RateLimits.MaxGetChunkOpsPerSecond = 1000 + config.RateLimits.GetChunkOpsBurstiness = 1000 + config.RateLimits.MaxGetChunkOpsPerSecondClient = 1000 + config.RateLimits.GetChunkOpsBurstinessClient = 1000 server, err := NewServer( context.Background(), logger, @@ -904,6 +931,10 @@ func TestBatchedReadWriteChunksWithSharding(t *testing.T) { // This is the server used to read it back config := defaultConfig() config.RelayIDs = shardList + config.RateLimits.MaxGetChunkOpsPerSecond = 1000 + config.RateLimits.GetChunkOpsBurstiness = 1000 + config.RateLimits.MaxGetChunkOpsPerSecondClient = 1000 + config.RateLimits.GetChunkOpsBurstinessClient = 1000 server, err := NewServer( context.Background(), logger,