Skip to content

Commit

Permalink
Optimize the batch deserialization perf (#573)
Browse files Browse the repository at this point in the history
  • Loading branch information
jianoaix authored May 21, 2024
1 parent 374500a commit 27b4caa
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 86 deletions.
128 changes: 65 additions & 63 deletions node/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,37 +40,38 @@ var (

// Config contains all of the configuration information for a DA node.
type Config struct {
Hostname string
RetrievalPort string
DispersalPort string
InternalRetrievalPort string
InternalDispersalPort string
EnableNodeApi bool
NodeApiPort string
EnableMetrics bool
MetricsPort string
OnchainMetricsInterval int64
Timeout time.Duration
RegisterNodeAtStart bool
ExpirationPollIntervalSec uint64
EnableTestMode bool
OverrideBlockStaleMeasure int64
OverrideStoreDurationBlocks int64
QuorumIDList []core.QuorumID
DbPath string
LogPath string
PrivateBls string
ID core.OperatorID
BLSOperatorStateRetrieverAddr string
EigenDAServiceManagerAddr string
PubIPProvider string
PubIPCheckInterval time.Duration
ChurnerUrl string
DataApiUrl string
NumBatchValidators int
ClientIPHeader string
UseSecureGrpc bool
ReachabilityPollIntervalSec uint64
Hostname string
RetrievalPort string
DispersalPort string
InternalRetrievalPort string
InternalDispersalPort string
EnableNodeApi bool
NodeApiPort string
EnableMetrics bool
MetricsPort string
OnchainMetricsInterval int64
Timeout time.Duration
RegisterNodeAtStart bool
ExpirationPollIntervalSec uint64
EnableTestMode bool
OverrideBlockStaleMeasure int64
OverrideStoreDurationBlocks int64
QuorumIDList []core.QuorumID
DbPath string
LogPath string
PrivateBls string
ID core.OperatorID
BLSOperatorStateRetrieverAddr string
EigenDAServiceManagerAddr string
PubIPProvider string
PubIPCheckInterval time.Duration
ChurnerUrl string
DataApiUrl string
NumBatchValidators int
NumBatchDeserializationWorkers int
ClientIPHeader string
UseSecureGrpc bool
ReachabilityPollIntervalSec uint64

EthClientConfig geth.EthClientConfig
LoggerConfig common.LoggerConfig
Expand Down Expand Up @@ -167,37 +168,38 @@ func NewConfig(ctx *cli.Context) (*Config, error) {
}

return &Config{
Hostname: ctx.GlobalString(flags.HostnameFlag.Name),
DispersalPort: ctx.GlobalString(flags.DispersalPortFlag.Name),
RetrievalPort: ctx.GlobalString(flags.RetrievalPortFlag.Name),
InternalDispersalPort: internalDispersalFlag,
InternalRetrievalPort: internalRetrievalFlag,
EnableNodeApi: ctx.GlobalBool(flags.EnableNodeApiFlag.Name),
NodeApiPort: ctx.GlobalString(flags.NodeApiPortFlag.Name),
EnableMetrics: ctx.GlobalBool(flags.EnableMetricsFlag.Name),
MetricsPort: ctx.GlobalString(flags.MetricsPortFlag.Name),
OnchainMetricsInterval: ctx.GlobalInt64(flags.OnchainMetricsIntervalFlag.Name),
Timeout: timeout,
RegisterNodeAtStart: registerNodeAtStart,
ExpirationPollIntervalSec: expirationPollIntervalSec,
ReachabilityPollIntervalSec: reachabilityPollIntervalSec,
EnableTestMode: testMode,
OverrideBlockStaleMeasure: ctx.GlobalInt64(flags.OverrideBlockStaleMeasureFlag.Name),
OverrideStoreDurationBlocks: ctx.GlobalInt64(flags.OverrideStoreDurationBlocksFlag.Name),
QuorumIDList: ids,
DbPath: ctx.GlobalString(flags.DbPathFlag.Name),
PrivateBls: privateBls,
EthClientConfig: ethClientConfig,
EncoderConfig: kzg.ReadCLIConfig(ctx),
LoggerConfig: *loggerConfig,
BLSOperatorStateRetrieverAddr: ctx.GlobalString(flags.BlsOperatorStateRetrieverFlag.Name),
EigenDAServiceManagerAddr: ctx.GlobalString(flags.EigenDAServiceManagerFlag.Name),
PubIPProvider: ctx.GlobalString(flags.PubIPProviderFlag.Name),
PubIPCheckInterval: pubIPCheckInterval,
ChurnerUrl: ctx.GlobalString(flags.ChurnerUrlFlag.Name),
DataApiUrl: ctx.GlobalString(flags.DataApiUrlFlag.Name),
NumBatchValidators: ctx.GlobalInt(flags.NumBatchValidatorsFlag.Name),
ClientIPHeader: ctx.GlobalString(flags.ClientIPHeaderFlag.Name),
UseSecureGrpc: ctx.GlobalBoolT(flags.ChurnerUseSecureGRPC.Name),
Hostname: ctx.GlobalString(flags.HostnameFlag.Name),
DispersalPort: ctx.GlobalString(flags.DispersalPortFlag.Name),
RetrievalPort: ctx.GlobalString(flags.RetrievalPortFlag.Name),
InternalDispersalPort: internalDispersalFlag,
InternalRetrievalPort: internalRetrievalFlag,
EnableNodeApi: ctx.GlobalBool(flags.EnableNodeApiFlag.Name),
NodeApiPort: ctx.GlobalString(flags.NodeApiPortFlag.Name),
EnableMetrics: ctx.GlobalBool(flags.EnableMetricsFlag.Name),
MetricsPort: ctx.GlobalString(flags.MetricsPortFlag.Name),
OnchainMetricsInterval: ctx.GlobalInt64(flags.OnchainMetricsIntervalFlag.Name),
Timeout: timeout,
RegisterNodeAtStart: registerNodeAtStart,
ExpirationPollIntervalSec: expirationPollIntervalSec,
ReachabilityPollIntervalSec: reachabilityPollIntervalSec,
EnableTestMode: testMode,
OverrideBlockStaleMeasure: ctx.GlobalInt64(flags.OverrideBlockStaleMeasureFlag.Name),
OverrideStoreDurationBlocks: ctx.GlobalInt64(flags.OverrideStoreDurationBlocksFlag.Name),
QuorumIDList: ids,
DbPath: ctx.GlobalString(flags.DbPathFlag.Name),
PrivateBls: privateBls,
EthClientConfig: ethClientConfig,
EncoderConfig: kzg.ReadCLIConfig(ctx),
LoggerConfig: *loggerConfig,
BLSOperatorStateRetrieverAddr: ctx.GlobalString(flags.BlsOperatorStateRetrieverFlag.Name),
EigenDAServiceManagerAddr: ctx.GlobalString(flags.EigenDAServiceManagerFlag.Name),
PubIPProvider: ctx.GlobalString(flags.PubIPProviderFlag.Name),
PubIPCheckInterval: pubIPCheckInterval,
ChurnerUrl: ctx.GlobalString(flags.ChurnerUrlFlag.Name),
DataApiUrl: ctx.GlobalString(flags.DataApiUrlFlag.Name),
NumBatchValidators: ctx.GlobalInt(flags.NumBatchValidatorsFlag.Name),
NumBatchDeserializationWorkers: ctx.GlobalInt(flags.NumBatchDeserializationWorkersFlag.Name),
ClientIPHeader: ctx.GlobalString(flags.ClientIPHeaderFlag.Name),
UseSecureGrpc: ctx.GlobalBoolT(flags.ChurnerUseSecureGRPC.Name),
}, nil
}
7 changes: 7 additions & 0 deletions node/flags/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,13 @@ var (
EnvVar: common.PrefixEnvVar(EnvVarPrefix, "NUM_BATCH_VALIDATORS"),
Value: 128,
}
// NumBatchDeserializationWorkersFlag is an optional integer flag that caps the
// number of parallel workers used to deserialize a batch. It defaults to 128.
// Its EnvVar name is built by common.PrefixEnvVar from EnvVarPrefix and
// "NUM_BATCH_DESERIALIZATION_WORKERS".
NumBatchDeserializationWorkersFlag = cli.IntFlag{
Name: "num-batch-deserialization-workers",
Usage: "maximum number of parallel workers used to deserialize a batch (defaults to 128)",
Required: false,
EnvVar: common.PrefixEnvVar(EnvVarPrefix, "NUM_BATCH_DESERIALIZATION_WORKERS"),
Value: 128,
}

// Test only, DO NOT USE the following flags in production

Expand Down
2 changes: 1 addition & 1 deletion node/grpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ func (s *Server) handleStoreChunksRequest(ctx context.Context, in *pb.StoreChunk
return nil, err
}

blobs, err := GetBlobMessages(in)
blobs, err := GetBlobMessages(in, s.node.Config.NumBatchDeserializationWorkers)
if err != nil {
return nil, err
}
Expand Down
63 changes: 41 additions & 22 deletions node/grpc/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/Layr-Labs/eigenda/node"
"github.com/consensys/gnark-crypto/ecc/bn254"
"github.com/consensys/gnark-crypto/ecc/bn254/fp"
"github.com/gammazero/workerpool"
"github.com/wealdtech/go-merkletree"
"github.com/wealdtech/go-merkletree/keccak256"
"google.golang.org/protobuf/proto"
Expand All @@ -34,34 +35,52 @@ func GetBatchHeader(in *pb.StoreChunksRequest) (*core.BatchHeader, error) {
// GetBlobMessages constructs a core.BlobMessage array from a pb.StoreChunksRequest proto.
// Each blob is deserialized as a separate task on a worker pool created with numWorkers.
// Note the StoreChunksRequest is validated as soon as it enters the node gRPC
// interface, see grpc.Server.validateStoreChunkRequest.
func GetBlobMessages(in *pb.StoreChunksRequest) ([]*core.BlobMessage, error) {
func GetBlobMessages(in *pb.StoreChunksRequest, numWorkers int) ([]*core.BlobMessage, error) {
blobs := make([]*core.BlobMessage, len(in.GetBlobs()))
pool := workerpool.New(numWorkers)
resultChan := make(chan error, len(blobs))
for i, blob := range in.GetBlobs() {
blobHeader, err := GetBlobHeaderFromProto(blob.GetHeader())

if err != nil {
return nil, err
}
if len(blob.GetBundles()) != len(blob.GetHeader().GetQuorumHeaders()) {
return nil, fmt.Errorf("number of quorum headers (%d) does not match number of bundles in blob message (%d)", len(blob.GetHeader().GetQuorumHeaders()), len(blob.GetBundles()))
}
i := i
blob := blob
pool.Submit(func() {
blobHeader, err := GetBlobHeaderFromProto(blob.GetHeader())

if err != nil {
resultChan <- err
return
}
if len(blob.GetBundles()) != len(blob.GetHeader().GetQuorumHeaders()) {
resultChan <- fmt.Errorf("number of quorum headers (%d) does not match number of bundles in blob message (%d)", len(blob.GetHeader().GetQuorumHeaders()), len(blob.GetBundles()))
return
}

bundles := make(map[core.QuorumID]core.Bundle, len(blob.GetBundles()))
for j, chunks := range blob.GetBundles() {
quorumID := blob.GetHeader().GetQuorumHeaders()[j].GetQuorumId()
bundles[uint8(quorumID)] = make([]*encoding.Frame, len(chunks.GetChunks()))
for k, data := range chunks.GetChunks() {
chunk, err := new(encoding.Frame).Deserialize(data)
if err != nil {
return nil, err
bundles := make(map[core.QuorumID]core.Bundle, len(blob.GetBundles()))
for j, chunks := range blob.GetBundles() {
quorumID := blob.GetHeader().GetQuorumHeaders()[j].GetQuorumId()
bundles[uint8(quorumID)] = make([]*encoding.Frame, len(chunks.GetChunks()))
for k, data := range chunks.GetChunks() {
chunk, err := new(encoding.Frame).Deserialize(data)
if err != nil {
resultChan <- err
return
}
bundles[uint8(quorumID)][k] = chunk
}
bundles[uint8(quorumID)][k] = chunk
}
}

blobs[i] = &core.BlobMessage{
BlobHeader: blobHeader,
Bundles: bundles,
blobs[i] = &core.BlobMessage{
BlobHeader: blobHeader,
Bundles: bundles,
}

resultChan <- nil
})
}
pool.StopWait()
close(resultChan)
for err := range resultChan {
if err != nil {
return nil, err
}
}
return blobs, nil
Expand Down

0 comments on commit 27b4caa

Please sign in to comment.