Skip to content

Commit

Permalink
read allowlist from json file
Browse files Browse the repository at this point in the history
  • Loading branch information
ian-shim committed Apr 20, 2024
1 parent da23ea7 commit 861602b
Show file tree
Hide file tree
Showing 4 changed files with 205 additions and 51 deletions.
163 changes: 119 additions & 44 deletions disperser/apiserver/config.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package apiserver

import (
"encoding/json"
"errors"
"io"
"log"
"os"
"strconv"
"strings"

Expand All @@ -19,6 +22,7 @@ const (
PerUserUnauthBlobRateFlagName = "auth.per-user-unauth-blob-rate"
ClientIPHeaderFlagName = "auth.client-ip-header"
AllowlistFlagName = "auth.allowlist"
AllowlistFileFlagName = "auth.allowlist-file"

RetrievalBlobRateFlagName = "auth.retrieval-blob-rate"
RetrievalThroughputFlagName = "auth.retrieval-throughput"
Expand All @@ -36,12 +40,21 @@ type QuorumRateInfo struct {
}

type PerUserRateInfo struct {
Name string
Throughput common.RateParam
BlobRate common.RateParam
}

type Allowlist = map[string]map[core.QuorumID]PerUserRateInfo

type AllowlistEntry struct {
Name string `json:"name"`
Account string `json:"account"`
QuorumID uint8 `json:"quorumID"`
BlobRate float64 `json:"blobRate"`
ByteRate float64 `json:"byteRate"`
}

type RateConfig struct {
QuorumRateInfos map[core.QuorumID]QuorumRateInfo
ClientIPHeader string
Expand All @@ -51,6 +64,25 @@ type RateConfig struct {
RetrievalThroughput common.RateParam
}

// Deprecated: use AllowlistFileFlagName instead
func AllowlistFlag(envPrefix string) cli.Flag {
return cli.StringSliceFlag{
Name: AllowlistFlagName,
Usage: "Allowlist of IPs or ethereum addresses (including initial \"0x\") and corresponding blob/byte rates to bypass rate limiting. Format: [<IP>||<ETH ADDRESS>]/<quorum ID>/<blob rate>/<byte rate>. Example: 127.0.0.1/0/10/10485760",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"),
Required: false,
Value: &cli.StringSlice{},
}
}
func AllowlistFileFlag(envPrefix string) cli.Flag {
return cli.StringFlag{
Name: AllowlistFileFlagName,
Usage: "Path to a file containing the allowlist of IPs or ethereum addresses (including initial \"0x\") and corresponding blob/byte rates to bypass rate limiting. This file must be in JSON format",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST_FILE"),
Required: false,
}
}

func CLIFlags(envPrefix string) []cli.Flag {
return []cli.Flag{
cli.IntSliceFlag{
Expand Down Expand Up @@ -90,13 +122,8 @@ func CLIFlags(envPrefix string) []cli.Flag {
Value: "",
EnvVar: common.PrefixEnvVar(envPrefix, "CLIENT_IP_HEADER"),
},
cli.StringSliceFlag{
Name: AllowlistFlagName,
Usage: "Allowlist of IPs or ethereum addresses (including initial \"0x\") and corresponding blob/byte rates to bypass rate limiting. Format: [<IP>||<ETH ADDRESS>]/<quorum ID>/<blob rate>/<byte rate>. Example: 127.0.0.1/0/10/10485760",
EnvVar: common.PrefixEnvVar(envPrefix, "ALLOWLIST"),
Required: false,
Value: &cli.StringSlice{},
},
AllowlistFlag(envPrefix),
AllowlistFileFlag(envPrefix),
cli.IntFlag{
Name: RetrievalBlobRateFlagName,
Usage: "The blob rate limit for retrieval requests (Blobs/sec)",
Expand All @@ -112,43 +139,9 @@ func CLIFlags(envPrefix string) []cli.Flag {
}
}

func ReadCLIConfig(c *cli.Context) (RateConfig, error) {

numQuorums := len(c.IntSlice(RegisteredQuorumFlagName))
if len(c.StringSlice(TotalUnauthBlobRateFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of total unauth blob rates does not match number of quorums")
}
if len(c.StringSlice(PerUserUnauthBlobRateFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of per user unauth blob intervals does not match number of quorums")
}
if len(c.IntSlice(TotalUnauthThroughputFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of total unauth throughput does not match number of quorums")
}
if len(c.IntSlice(PerUserUnauthThroughputFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of per user unauth throughput does not match number of quorums")
}

quorumRateInfos := make(map[core.QuorumID]QuorumRateInfo)
for ind, quorumID := range c.IntSlice(RegisteredQuorumFlagName) {

totalBlobRate, err := strconv.ParseFloat(c.StringSlice(TotalUnauthBlobRateFlagName)[ind], 64)
if err != nil {
return RateConfig{}, err
}
accountBlobRate, err := strconv.ParseFloat(c.StringSlice(PerUserUnauthBlobRateFlagName)[ind], 64)
if err != nil {
return RateConfig{}, err
}

quorumRateInfos[core.QuorumID(quorumID)] = QuorumRateInfo{
TotalUnauthThroughput: common.RateParam(c.IntSlice(TotalUnauthThroughputFlagName)[ind]),
PerUserUnauthThroughput: common.RateParam(c.IntSlice(PerUserUnauthThroughputFlagName)[ind]),
TotalUnauthBlobRate: common.RateParam(totalBlobRate * blobRateMultiplier),
PerUserUnauthBlobRate: common.RateParam(accountBlobRate * blobRateMultiplier),
}
}

// Parse allowlist
func parseAllowlistEntry(c *cli.Context) Allowlist {
// Parse from AllowlistFlagName
// Remove when AllowlistFlagName is deprecated and no longer used
allowlist := make(Allowlist)
for _, allowlistEntry := range c.StringSlice(AllowlistFlagName) {
allowlistEntrySplit := strings.Split(allowlistEntry, "/")
Expand Down Expand Up @@ -188,6 +181,88 @@ func ReadCLIConfig(c *cli.Context) (RateConfig, error) {
}
}

// Parse from AllowlistFileFlagName
allowlistFileName := c.String(AllowlistFileFlagName)
if allowlistFileName != "" {
allowlistFile, err := os.Open(allowlistFileName)
if err != nil {
log.Printf("failed to read allowlist file: %s", err)
return allowlist
}
defer allowlistFile.Close()
var allowlistEntries []AllowlistEntry
content, err := io.ReadAll(allowlistFile)
if err != nil {
log.Printf("failed to load allowlist file content: %s", err)
return allowlist
}
err = json.Unmarshal(content, &allowlistEntries)
if err != nil {
log.Printf("failed to parse allowlist file content: %s", err)
return allowlist
}

for _, entry := range allowlistEntries {
rateInfoByQuorum, ok := allowlist[entry.Account]
if !ok {
allowlist[entry.Account] = map[core.QuorumID]PerUserRateInfo{
core.QuorumID(entry.QuorumID): {
Name: entry.Name,
Throughput: common.RateParam(entry.ByteRate),
BlobRate: common.RateParam(entry.BlobRate * blobRateMultiplier),
},
}
} else {
rateInfoByQuorum[core.QuorumID(entry.QuorumID)] = PerUserRateInfo{
Name: entry.Name,
Throughput: common.RateParam(entry.ByteRate),
BlobRate: common.RateParam(entry.BlobRate * blobRateMultiplier),
}
}
}
}

return allowlist
}

func ReadCLIConfig(c *cli.Context) (RateConfig, error) {

numQuorums := len(c.IntSlice(RegisteredQuorumFlagName))
if len(c.StringSlice(TotalUnauthBlobRateFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of total unauth blob rates does not match number of quorums")
}
if len(c.StringSlice(PerUserUnauthBlobRateFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of per user unauth blob intervals does not match number of quorums")
}
if len(c.IntSlice(TotalUnauthThroughputFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of total unauth throughput does not match number of quorums")
}
if len(c.IntSlice(PerUserUnauthThroughputFlagName)) != numQuorums {
return RateConfig{}, errors.New("number of per user unauth throughput does not match number of quorums")
}

quorumRateInfos := make(map[core.QuorumID]QuorumRateInfo)
for ind, quorumID := range c.IntSlice(RegisteredQuorumFlagName) {

totalBlobRate, err := strconv.ParseFloat(c.StringSlice(TotalUnauthBlobRateFlagName)[ind], 64)
if err != nil {
return RateConfig{}, err
}
accountBlobRate, err := strconv.ParseFloat(c.StringSlice(PerUserUnauthBlobRateFlagName)[ind], 64)
if err != nil {
return RateConfig{}, err
}

quorumRateInfos[core.QuorumID(quorumID)] = QuorumRateInfo{
TotalUnauthThroughput: common.RateParam(c.IntSlice(TotalUnauthThroughputFlagName)[ind]),
PerUserUnauthThroughput: common.RateParam(c.IntSlice(PerUserUnauthThroughputFlagName)[ind]),
TotalUnauthBlobRate: common.RateParam(totalBlobRate * blobRateMultiplier),
PerUserUnauthBlobRate: common.RateParam(accountBlobRate * blobRateMultiplier),
}
}

allowlist := parseAllowlistEntry(c)

return RateConfig{
QuorumRateInfos: quorumRateInfos,
ClientIPHeader: c.String(ClientIPHeaderFlagName),
Expand Down
2 changes: 1 addition & 1 deletion disperser/apiserver/ratelimit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ func TestRetrievalRateLimit(t *testing.T) {
numLimited := 0
tt := time.Now()
for i := 0; i < 15; i++ {
_, err = retrieveBlob(t, dispersalServer, requestID, 1)
_, err = retrieveBlob(dispersalServer, requestID, 1)
fmt.Println(time.Since(tt))
tt = time.Now()
if err != nil && strings.Contains(err.Error(), "request ratelimited: Retrieval blob rate limit") {
Expand Down
10 changes: 7 additions & 3 deletions disperser/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@ func NewDispersalServer(
rateConfig RateConfig,
) *DispersalServer {
logger := _logger.With("component", "DispersalServer")
for ip, rateInfoByQuorum := range rateConfig.Allowlist {
for account, rateInfoByQuorum := range rateConfig.Allowlist {
for quorumID, rateInfo := range rateInfoByQuorum {
logger.Info("[Allowlist]", "ip", ip, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate)
logger.Info("[Allowlist]", "account", account, "name", rateInfo.Name, "quorumID", quorumID, "throughput", rateInfo.Throughput, "blobRate", rateInfo.BlobRate)
}
}

Expand Down Expand Up @@ -306,6 +306,7 @@ func (s *DispersalServer) getAccountRate(origin, authenticatedAddress string, qu
}

rates := &PerUserRateInfo{
Name: "",
Throughput: unauthRates.PerUserUnauthThroughput,
BlobRate: unauthRates.PerUserUnauthBlobRate,
}
Expand All @@ -323,6 +324,7 @@ func (s *DispersalServer) getAccountRate(origin, authenticatedAddress string, qu
if rateInfo.BlobRate > 0 {
rates.BlobRate = rateInfo.BlobRate
}
rates.Name = rateInfo.Name
return rates, key, nil
}
}
Expand Down Expand Up @@ -429,7 +431,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRatesToHeader(ctx context.Context

blobSize := len(blob.Data)
length := encoding.GetBlobLength(uint(blobSize))

requesterName := ""
for i, param := range blob.RequestHeader.SecurityParams {

globalRates, ok := s.rateConfig.QuorumRateInfos[param.QuorumID]
Expand All @@ -443,6 +445,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRatesToHeader(ctx context.Context
s.metrics.HandleInternalFailureRpcRequest(apiMethodName)
return api.NewInternalError(err.Error())
}
requesterName = accountRates.Name

// Update the quorum rate
blob.RequestHeader.SecurityParams[i].QuorumRate = accountRates.Throughput
Expand Down Expand Up @@ -521,6 +524,7 @@ func (s *DispersalServer) checkRateLimitsAndAddRatesToHeader(ctx context.Context
} else if info.RateType == AccountThroughputType || info.RateType == AccountBlobRateType {
s.metrics.HandleAccountRateLimitedRpcRequest(apiMethodName)
s.metrics.HandleAccountRateLimitedRequest(fmt.Sprint(info.QuorumID), blobSize, apiMethodName)
s.logger.Info("request ratelimited", "requesterName", requesterName, "requesterID", params.RequesterID, "rateType", info.RateType.String(), "quorum", info.QuorumID)
}
errorString := fmt.Sprintf("request ratelimited: %s for quorum %d", info.RateType.String(), info.QuorumID)
return api.NewResourceExhaustedError(errorString)
Expand Down
Loading

0 comments on commit 861602b

Please sign in to comment.