From 861602bee1f923eb3070126a7ebefa509280b966 Mon Sep 17 00:00:00 2001 From: Ian Shim Date: Fri, 19 Apr 2024 16:27:36 -0700 Subject: [PATCH] read allowlist from json file --- disperser/apiserver/config.go | 163 +++++++++++++++++++------- disperser/apiserver/ratelimit_test.go | 2 +- disperser/apiserver/server.go | 10 +- disperser/apiserver/server_test.go | 81 ++++++++++++- 4 files changed, 205 insertions(+), 51 deletions(-) diff --git a/disperser/apiserver/config.go b/disperser/apiserver/config.go index a3b3bd3e13..0cd8338a65 100644 --- a/disperser/apiserver/config.go +++ b/disperser/apiserver/config.go @@ -1,8 +1,11 @@ package apiserver import ( + "encoding/json" "errors" + "io" "log" + "os" "strconv" "strings" @@ -19,6 +22,7 @@ const ( PerUserUnauthBlobRateFlagName = "auth.per-user-unauth-blob-rate" ClientIPHeaderFlagName = "auth.client-ip-header" AllowlistFlagName = "auth.allowlist" + AllowlistFileFlagName = "auth.allowlist-file" RetrievalBlobRateFlagName = "auth.retrieval-blob-rate" RetrievalThroughputFlagName = "auth.retrieval-throughput" @@ -36,12 +40,21 @@ type QuorumRateInfo struct { } type PerUserRateInfo struct { + Name string Throughput common.RateParam BlobRate common.RateParam } type Allowlist = map[string]map[core.QuorumID]PerUserRateInfo +type AllowlistEntry struct { + Name string `json:"name"` + Account string `json:"account"` + QuorumID uint8 `json:"quorumID"` + BlobRate float64 `json:"blobRate"` + ByteRate float64 `json:"byteRate"` +} + type RateConfig struct { QuorumRateInfos map[core.QuorumID]QuorumRateInfo ClientIPHeader string @@ -51,6 +64,25 @@ type RateConfig struct { RetrievalThroughput common.RateParam } +// Deprecated: use AllowlistFileFlagName instead +func AllowlistFlag(envPrefix string) cli.Flag { + return cli.StringSliceFlag{ + Name: AllowlistFlagName, + Usage: "Allowlist of IPs or ethereum addresses (including initial \"0x\") and corresponding blob/byte rates to bypass rate limiting. Format: [||]///. Example: 127.0.0.1/0/10/10485760", + EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"), + Required: false, + Value: &cli.StringSlice{}, + } +} +func AllowlistFileFlag(envPrefix string) cli.Flag { + return cli.StringFlag{ + Name: AllowlistFileFlagName, + Usage: "Path to a file containing the allowlist of IPs or ethereum addresses (including initial \"0x\") and corresponding blob/byte rates to bypass rate limiting. This file must be in JSON format", + EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST_FILE"), + Required: false, + } +} + func CLIFlags(envPrefix string) []cli.Flag { return []cli.Flag{ cli.IntSliceFlag{ @@ -90,13 +122,8 @@ func CLIFlags(envPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "CLIENT_IP_HEADER"), }, - cli.StringSliceFlag{ - Name: AllowlistFlagName, - Usage: "Allowlist of IPs or ethereum addresses (including initial \"0x\") and corresponding blob/byte rates to bypass rate limiting. Format: [||]///. Example: 127.0.0.1/0/10/10485760", - EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"), - Required: false, - Value: &cli.StringSlice{}, - }, + AllowlistFlag(envPrefix), + AllowlistFileFlag(envPrefix), cli.IntFlag{ Name: RetrievalBlobRateFlagName, Usage: "The blob rate limit for retrieval requests (Blobs/sec)", @@ -112,43 +139,9 @@ func CLIFlags(envPrefix string) []cli.Flag { } } -func ReadCLIConfig(c *cli.Context) (RateConfig, error) { - - numQuorums := len(c.IntSlice(RegisteredQuorumFlagName)) - if len(c.StringSlice(TotalUnauthBlobRateFlagName)) != numQuorums { - return RateConfig{}, errors.New("number of total unauth blob rates does not match number of quorums") - } - if len(c.StringSlice(PerUserUnauthBlobRateFlagName)) != numQuorums { - return RateConfig{}, errors.New("number of per user unauth blob intervals does not match number of quorums") - } - if len(c.IntSlice(TotalUnauthThroughputFlagName)) != numQuorums { - return RateConfig{}, errors.New("number of total unauth throughput does not match number of quorums") - } - if len(c.IntSlice(PerUserUnauthThroughputFlagName)) != numQuorums { - return RateConfig{}, errors.New("number of per user unauth throughput does not match number of quorums") - } - - quorumRateInfos := make(map[core.QuorumID]QuorumRateInfo) - for ind, quorumID := range c.IntSlice(RegisteredQuorumFlagName) { - - totalBlobRate, err := strconv.ParseFloat(c.StringSlice(TotalUnauthBlobRateFlagName)[ind], 64) - if err != nil { - return RateConfig{}, err - } - accountBlobRate, err := strconv.ParseFloat(c.StringSlice(PerUserUnauthBlobRateFlagName)[ind], 64) - if err != nil { - return RateConfig{}, err - } - - quorumRateInfos[core.QuorumID(quorumID)] = QuorumRateInfo{ - TotalUnauthThroughput: common.RateParam(c.IntSlice(TotalUnauthThroughputFlagName)[ind]), - PerUserUnauthThroughput: common.RateParam(c.IntSlice(PerUserUnauthThroughputFlagName)[ind]), - TotalUnauthBlobRate: common.RateParam(totalBlobRate * blobRateMultiplier), - PerUserUnauthBlobRate: common.RateParam(accountBlobRate * blobRateMultiplier), - } - } - - // Parse allowlist +func parseAllowlistEntry(c *cli.Context) Allowlist { + // Parse from AllowlistFlagName + // Remove when AllowlistFlagName is deprecated and no longer used allowlist := make(Allowlist) for _, allowlistEntry := range c.StringSlice(AllowlistFlagName) { allowlistEntrySplit := strings.Split(allowlistEntry, "/") @@ -188,6 +181,88 @@ func ReadCLIConfig(c *cli.Context) (RateConfig, error) { } } + // Parse from AllowlistFileFlagName + allowlistFileName := c.String(AllowlistFileFlagName) + if allowlistFileName != "" { + allowlistFile, err := os.Open(allowlistFileName) + if err != nil { + log.Printf("failed to read allowlist file: %s", err) + return allowlist + } + defer allowlistFile.Close() + var allowlistEntries []AllowlistEntry + content, err := io.ReadAll(allowlistFile) + if err != nil { + log.Printf("failed to load allowlist file content: %s", err) + return allowlist + } + err = json.Unmarshal(content, &allowlistEntries) + if err != nil { + log.Printf("failed to parse allowlist file content: %s", err) + return allowlist + } + + for _, entry := range allowlistEntries { + rateInfoByQuorum, ok := allowlist[entry.Account] + if !ok { + allowlist[entry.Account] = map[core.QuorumID]PerUserRateInfo{ + core.QuorumID(entry.QuorumID): { + Name: entry.Name, + Throughput: common.RateParam(entry.ByteRate), + BlobRate: common.RateParam(entry.BlobRate * blobRateMultiplier), + }, + } + } else { + rateInfoByQuorum[core.QuorumID(entry.QuorumID)] = PerUserRateInfo{ + Name: entry.Name, + Throughput: common.RateParam(entry.ByteRate), + BlobRate: common.RateParam(entry.BlobRate * blobRateMultiplier), + } + } + } + } + + return allowlist +} + +func ReadCLIConfig(c *cli.Context) (RateConfig, error) { + + numQuorums := len(c.IntSlice(RegisteredQuorumFlagName)) + if len(c.StringSlice(TotalUnauthBlobRateFlagName)) != numQuorums { + return RateConfig{}, errors.New("number of total unauth blob rates does not match number of quorums") + } + if len(c.StringSlice(PerUserUnauthBlobRateFlagName)) != numQuorums { + return RateConfig{}, errors.New("number of per user unauth blob intervals does not match number of quorums") + } + if len(c.IntSlice(TotalUnauthThroughputFlagName)) != numQuorums { + return RateConfig{}, errors.New("number of total unauth throughput does not match number of quorums") + } + if len(c.IntSlice(PerUserUnauthThroughputFlagName)) != numQuorums { + return RateConfig{}, errors.New("number of per user unauth throughput does not match number of quorums") + } + + quorumRateInfos := make(map[core.QuorumID]QuorumRateInfo) + for ind, quorumID := range c.IntSlice(RegisteredQuorumFlagName) { + + totalBlobRate, err := strconv.ParseFloat(c.StringSlice(TotalUnauthBlobRateFlagName)[ind], 64) + if err != nil { + return RateConfig{}, err + } + accountBlobRate, err := strconv.ParseFloat(c.StringSlice(PerUserUnauthBlobRateFlagName)[ind], 64) + if err != nil { + return RateConfig{}, err + } + + quorumRateInfos[core.QuorumID(quorumID)] = QuorumRateInfo{ + TotalUnauthThroughput: common.RateParam(c.IntSlice(TotalUnauthThroughputFlagName)[ind]), + PerUserUnauthThroughput: common.RateParam(c.IntSlice(PerUserUnauthThroughputFlagName)[ind]), + TotalUnauthBlobRate: common.RateParam(totalBlobRate * blobRateMultiplier), + PerUserUnauthBlobRate: common.RateParam(accountBlobRate * blobRateMultiplier), + } + } + + allowlist := parseAllowlistEntry(c) + return RateConfig{ QuorumRateInfos: quorumRateInfos, ClientIPHeader: c.String(ClientIPHeaderFlagName), diff --git a/disperser/apiserver/ratelimit_test.go b/disperser/apiserver/ratelimit_test.go index 86fef27c19..8a7941d663 100644 --- a/disperser/apiserver/ratelimit_test.go +++ b/disperser/apiserver/ratelimit_test.go @@ -198,7 +198,7 @@ func TestRetrievalRateLimit(t *testing.T) { numLimited := 0 tt := time.Now() for i := 0; i < 15; i++ { - _, err = retrieveBlob(t, dispersalServer, requestID, 1) + _, err = retrieveBlob(dispersalServer, requestID, 1) fmt.Println(time.Since(tt)) tt = time.Now() if err != nil && strings.Contains(err.Error(), "request ratelimited: Retrieval blob rate limit") { diff --git a/disperser/apiserver/server.go b/disperser/apiserver/server.go index 4a04a939bf..97d9ec458b 100644 --- a/disperser/apiserver/server.go +++ b/disperser/apiserver/server.go @@ -72,9 +72,9 @@ func NewDispersalServer( rateConfig RateConfig, ) *DispersalServer { logger := _logger.With("component", "DispersalServer") - for ip, rateInfoByQuorum := range rateConfig.Allowlist { + for account, rateInfoByQuorum := range rateConfig.Allowlist { for quorumID, rateInfo := range rateInfoByQuorum { - logger.Info("[Allowlist]", "ip", ip, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate) + logger.Info("[Allowlist]", "account", account, "name", rateInfo.Name, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate) } } @@ -306,6 +306,7 @@ func (s *DispersalServer) getAccountRate(origin, authenticatedAddress string, qu } rates := &PerUserRateInfo{ + Name: "", Throughput: unauthRates.PerUserUnauthThroughput, BlobRate: unauthRates.PerUserUnauthBlobRate, } @@ -323,6 +324,7 @@ func (s *DispersalServer) getAccountRate(origin, authenticatedAddress string, qu if rateInfo.BlobRate > 0 { rates.BlobRate = rateInfo.BlobRate } + rates.Name = rateInfo.Name return rates, key, nil } } @@ -429,7 +431,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRatesToHeader(ctx context.Context blobSize := len(blob.Data) length := encoding.GetBlobLength(uint(blobSize)) - + requesterName := "" for i, param := range blob.RequestHeader.SecurityParams { globalRates, ok := s.rateConfig.QuorumRateInfos[param.QuorumID] @@ -443,6 +445,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRatesToHeader(ctx context.Context s.metrics.HandleInternalFailureRpcRequest(apiMethodName) return api.NewInternalError(err.Error()) } + requesterName = accountRates.Name // Update the quorum rate blob.RequestHeader.SecurityParams[i].QuorumRate = accountRates.Throughput @@ -521,6 +524,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRatesToHeader(ctx context.Context } else if info.RateType == AccountThroughputType || info.RateType == AccountBlobRateType { s.metrics.HandleAccountRateLimitedRpcRequest(apiMethodName) s.metrics.HandleAccountRateLimitedRequest(fmt.Sprint(info.QuorumID), blobSize, apiMethodName) + s.logger.Info("request ratelimited", "requesterName", requesterName, "requesterID", params.RequesterID, "rateType", info.RateType.String(), "quorum", info.QuorumID) } errorString := fmt.Sprintf("request ratelimited: %s for quorum %d", info.RateType.String(), info.QuorumID) return api.NewResourceExhaustedError(errorString) diff --git a/disperser/apiserver/server_test.go b/disperser/apiserver/server_test.go index 7306a50c82..2d49639ee7 100644 --- a/disperser/apiserver/server_test.go +++ b/disperser/apiserver/server_test.go @@ -3,6 +3,7 @@ package apiserver_test import ( "context" "crypto/rand" + "flag" "fmt" "net" "os" @@ -19,6 +20,7 @@ import ( gethcommon "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/google/uuid" + "github.com/urfave/cli" pb "github.com/Layr-Labs/eigenda/api/grpc/disperser" "github.com/Layr-Labs/eigenda/common" @@ -362,7 +364,7 @@ func TestRetrieveBlob(t *testing.T) { assert.Equal(t, reply.GetStatus(), pb.BlobStatus_CONFIRMED) // Retrieve the blob and compare it with the original data - retrieveData, err := retrieveBlob(t, dispersalServer, requestID, 1) + retrieveData, err := retrieveBlob(dispersalServer, requestID, 1) assert.NoError(t, err) assert.Equal(t, data, retrieveData) @@ -390,7 +392,7 @@ func TestRetrieveBlobFailsWhenBlobNotConfirmed(t *testing.T) { assert.Equal(t, reply.GetStatus(), pb.BlobStatus_PROCESSING) // Try to retrieve the blob before it is confirmed - _, err = retrieveBlob(t, dispersalServer, requestID, 2) + _, err = retrieveBlob(dispersalServer, requestID, 2) assert.NotNil(t, err) assert.Equal(t, "rpc error: code = NotFound desc = no metadata found for the given batch header hash and blob index", err.Error()) @@ -419,6 +421,79 @@ func TestDisperseBlobWithExceedSizeLimit(t *testing.T) { assert.Equal(t, err.Error(), "rpc error: code = InvalidArgument desc = blob size cannot exceed 2 MiB") } +func TestParseAllowlist(t *testing.T) { + fs := flag.NewFlagSet("disperser", flag.ContinueOnError) + allowlistFlag := apiserver.AllowlistFlag("disperser") + allowlistFlag.Apply(fs) + allowlistFileFlag := apiserver.AllowlistFileFlag("disperser") + allowlistFileFlag.Apply(fs) + fs.Parse([]string{"--auth.allowlist", "52.202.222.39/0/100/52428800"}) + fs.Parse([]string{"--auth.allowlist", "52.202.222.39/1/100/52428800"}) + fs.Parse([]string{"--auth.allowlist", "3.225.189.232/0/1/1024"}) + + f, err := os.CreateTemp("", "allowlist.*.json") + assert.NoError(t, err) + defer os.Remove(f.Name()) + _, err = f.WriteString(` +[ + { + "name": "eigenlabs", + "account": "0.1.2.3", + "quorumID": 0, + "blobRate": 0.01, + "byteRate": 1024 + }, + { + "name": "eigenlabs", + "account": "0.1.2.3", + "quorumID": 1, + "blobRate": 1, + "byteRate": 1048576 + }, + { + "name": "foo", + "account": "5.5.5.5", + "quorumID": 1, + "blobRate": 0.1, + "byteRate": 4092 + } +] + `) + assert.NoError(t, err) + fs.Parse([]string{"--auth.allowlist-file", f.Name()}) + c := cli.NewContext(nil, fs, nil) + rateConfig, err := apiserver.ReadCLIConfig(c) + assert.NoError(t, err) + assert.Contains(t, rateConfig.Allowlist, "52.202.222.39") + assert.Contains(t, rateConfig.Allowlist, "3.225.189.232") + assert.Contains(t, rateConfig.Allowlist["52.202.222.39"], uint8(0)) + assert.Contains(t, rateConfig.Allowlist["52.202.222.39"], uint8(1)) + assert.Contains(t, rateConfig.Allowlist["3.225.189.232"], uint8(0)) + assert.NotContains(t, rateConfig.Allowlist["3.225.189.232"], uint8(1)) + assert.Equal(t, rateConfig.Allowlist["52.202.222.39"][0].BlobRate, uint32(100*1e6)) + assert.Equal(t, rateConfig.Allowlist["52.202.222.39"][0].Throughput, uint32(52428800)) + assert.Equal(t, rateConfig.Allowlist["52.202.222.39"][1].BlobRate, uint32(100*1e6)) + assert.Equal(t, rateConfig.Allowlist["52.202.222.39"][1].Throughput, uint32(52428800)) + assert.Equal(t, rateConfig.Allowlist["3.225.189.232"][0].BlobRate, uint32(1e6)) + assert.Equal(t, rateConfig.Allowlist["3.225.189.232"][0].Throughput, uint32(1024)) + + assert.Contains(t, rateConfig.Allowlist, "0.1.2.3") + assert.Contains(t, rateConfig.Allowlist, "5.5.5.5") + assert.Contains(t, rateConfig.Allowlist["0.1.2.3"], uint8(0)) + assert.Contains(t, rateConfig.Allowlist["0.1.2.3"], uint8(1)) + assert.Contains(t, rateConfig.Allowlist["5.5.5.5"], uint8(1)) + assert.NotContains(t, rateConfig.Allowlist["5.5.5.5"], uint8(0)) + assert.Equal(t, rateConfig.Allowlist["0.1.2.3"][0].Name, "eigenlabs") + assert.Equal(t, rateConfig.Allowlist["0.1.2.3"][0].BlobRate, uint32(0.01*1e6)) + assert.Equal(t, rateConfig.Allowlist["0.1.2.3"][0].Throughput, uint32(1024)) + assert.Equal(t, rateConfig.Allowlist["0.1.2.3"][1].Name, "eigenlabs") + assert.Equal(t, rateConfig.Allowlist["0.1.2.3"][1].BlobRate, uint32(1e6)) + assert.Equal(t, rateConfig.Allowlist["0.1.2.3"][1].Throughput, uint32(1048576)) + assert.Equal(t, rateConfig.Allowlist["5.5.5.5"][1].Name, "foo") + assert.Equal(t, rateConfig.Allowlist["5.5.5.5"][1].BlobRate, uint32(0.1*1e6)) + assert.Equal(t, rateConfig.Allowlist["5.5.5.5"][1].Throughput, uint32(4092)) +} + func setup(m *testing.M) { deployLocalStack = !(os.Getenv("DEPLOY_LOCALSTACK") == "false") @@ -560,7 +635,7 @@ func disperseBlob(t *testing.T, server *apiserver.DispersalServer, data []byte) return reply.GetResult(), uint(len(data)), reply.GetRequestId() } -func retrieveBlob(t *testing.T, server *apiserver.DispersalServer, requestID []byte, blobIndex uint32) ([]byte, error) { +func retrieveBlob(server *apiserver.DispersalServer, requestID []byte, blobIndex uint32) ([]byte, error) { p := &peer.Peer{ Addr: &net.TCPAddr{ IP: net.ParseIP("0.0.0.0"),