Skip to content

Commit

Permalink
move blob validation early
Browse files Browse the repository at this point in the history
  • Loading branch information
ian-shim committed Mar 18, 2024
1 parent a002ac2 commit 9ad5b45
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 63 deletions.
13 changes: 6 additions & 7 deletions core/auth/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ func NewAuthenticator(config AuthConfig) core.BlobRequestAuthenticator {
}

func (*authenticator) AuthenticateBlobRequest(header core.BlobAuthHeader) error {
sig := header.AuthenticationData

// Ensure the signature is 65 bytes (Recovery ID is the last byte)
if sig == nil || len(sig) != 65 {
return fmt.Errorf("signature length is unexpected: %d", len(sig))
}

buf := make([]byte, 4)
binary.BigEndian.PutUint32(buf, header.Nonce)
Expand All @@ -37,19 +43,12 @@ func (*authenticator) AuthenticateBlobRequest(header core.BlobAuthHeader) error
return fmt.Errorf("failed to decode public key (%v): %v", header.AccountID, err)
}

sig := header.AuthenticationData

// Decode public key
pubKey, err := crypto.UnmarshalPubkey(publicKeyBytes)
if err != nil {
return fmt.Errorf("failed to decode public key (%v): %v", header.AccountID, err)
}

// Ensure the signature is 65 bytes (Recovery ID is the last byte)
if sig == nil || len(sig) != 65 {
return fmt.Errorf("signature length is unexpected: %d", len(sig))
}

// Verify the signature
sigPublicKeyECDSA, err := crypto.SigToPub(hash, sig)
if err != nil {
Expand Down
20 changes: 9 additions & 11 deletions core/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,17 +75,15 @@ type BlobRequestHeader struct {
SecurityParams []*SecurityParam `json:"security_params"`
}

func (h *BlobRequestHeader) Validate() error {
for _, quorum := range h.SecurityParams {
if quorum.QuorumThreshold < quorum.AdversaryThreshold+10 {
return errors.New("invalid request: quorum threshold must be >= 10 + adversary threshold")
}
if quorum.QuorumThreshold > 100 {
return errors.New("invalid request: quorum threshold exceeds 100")
}
if quorum.AdversaryThreshold == 0 {
return errors.New("invalid request: adversary threshold equals 0")
}
func (sp *SecurityParam) Validate() error {
if sp.QuorumThreshold < sp.AdversaryThreshold+10 {
return errors.New("invalid request: quorum threshold must be >= 10 + adversary threshold")
}
if sp.QuorumThreshold > 100 {
return errors.New("invalid request: quorum threshold exceeds 100")
}
if sp.AdversaryThreshold == 0 {
return errors.New("invalid request: adversary threshold equals 0")
}
return nil
}
Expand Down
83 changes: 42 additions & 41 deletions disperser/apiserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,10 @@ func (s *DispersalServer) DisperseBlobAuthenticated(stream pb.Disperser_Disperse
return api.NewInvalidArgError("missing DisperseBlobRequest")
}

blob := getBlobFromRequest(request.DisperseRequest)
blob, err := getBlobFromRequest(request.DisperseRequest)
if err != nil {
return api.NewInvalidArgError(fmt.Sprintf("failed to get blob from request: %v", err))
}

// Get the ethereum address associated with the public key. This is just for convenience so we can put addresses instead of public keys in the allowlist.
// Decode public key
Expand Down Expand Up @@ -171,7 +174,10 @@ func (s *DispersalServer) DisperseBlobAuthenticated(stream pb.Disperser_Disperse

func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlobRequest) (*pb.DisperseBlobReply, error) {

blob := getBlobFromRequest(req)
blob, err := getBlobFromRequest(req)
if err != nil {
return nil, api.NewInvalidArgError(fmt.Sprintf("failed to get blob from request: %v", err))
}

reply, err := s.disperseBlob(ctx, blob, "")
if err != nil {
Expand All @@ -181,24 +187,8 @@ func (s *DispersalServer) DisperseBlob(ctx context.Context, req *pb.DisperseBlob
}

func (s *DispersalServer) validateBlobRequest(ctx context.Context, blob *core.Blob) error {

securityParams := blob.RequestHeader.SecurityParams
if len(securityParams) == 0 {
return errors.New("invalid request: security_params must not be empty")
}
if len(securityParams) > 256 {
return errors.New("invalid request: security_params must not exceed 256")
}

seenQuorums := make(map[uint8]struct{})
// The quorum ID must be in range [0, 254]. It'll actually be converted
// to uint8, so it cannot be greater than 254.
for _, param := range securityParams {
if _, ok := seenQuorums[param.QuorumID]; ok {
return errors.New("invalid request: security_params must not contain duplicate quorum_id")
}
seenQuorums[param.QuorumID] = struct{}{}

if param.QuorumID >= s.quorumCount {
err := s.updateQuorumCount(ctx)
if err != nil {
Expand All @@ -211,22 +201,7 @@ func (s *DispersalServer) validateBlobRequest(ctx context.Context, blob *core.Bl
}
}

blobSize := len(blob.Data)
// The blob size in bytes must be in range [1, maxBlobSize].
if blobSize > maxBlobSize {
return errors.New("blob size cannot exceed 2 MiB")
}
if blobSize == 0 {
return errors.New("blob size must be greater than 0")
}

if err := blob.RequestHeader.Validate(); err != nil {
s.logger.Warn("invalid header", "err", err)
return err
}

return nil

}

func (s *DispersalServer) disperseBlob(ctx context.Context, blob *core.Blob, authenticatedAddress string) (*pb.DisperseBlobReply, error) {
Expand Down Expand Up @@ -637,28 +612,54 @@ func getResponseStatus(status disperser.BlobStatus) pb.BlobStatus {
}
}

func getBlobFromRequest(req *pb.DisperseBlobRequest) *core.Blob {
params := make([]*core.SecurityParam, len(req.SecurityParams))
func getBlobFromRequest(req *pb.DisperseBlobRequest) (*core.Blob, error) {
blobSize := len(req.GetData())
// The blob size in bytes must be in range [1, maxBlobSize].
if blobSize > maxBlobSize {
return nil, errors.New("blob size cannot exceed 2 MiB")
}
if blobSize == 0 {
return nil, errors.New("blob size must be greater than 0")
}

securityParams := req.GetSecurityParams()
if len(securityParams) == 0 {
return nil, errors.New("invalid request: security_params must not be empty")
}
if len(securityParams) > 256 {
return nil, errors.New("invalid request: number of security_params must not exceed 256")
}

seenQuorums := make(map[uint8]struct{})
params := make([]*core.SecurityParam, len(req.SecurityParams))
for i, param := range req.GetSecurityParams() {
params[i] = &core.SecurityParam{
QuorumID: core.QuorumID(param.QuorumId),
quorumID := uint8(param.GetQuorumId())
// make sure there are no duplicate quorum IDs
if _, ok := seenQuorums[quorumID]; ok {
return nil, errors.New("invalid request: security_params must not contain duplicate quorum_id")
}
seenQuorums[quorumID] = struct{}{}

sp := &core.SecurityParam{
QuorumID: quorumID,
AdversaryThreshold: uint8(param.AdversaryThreshold),
QuorumThreshold: uint8(param.QuorumThreshold),
}
if err := sp.Validate(); err != nil {
return nil, err
}
params[i] = sp
}

data := req.GetData()

blob := &core.Blob{
RequestHeader: core.BlobRequestHeader{
BlobAuthHeader: core.BlobAuthHeader{
AccountID: req.AccountId,
},
SecurityParams: params,
},
Data: data,
Data: req.GetData(),
}

return blob
return blob, nil
}
4 changes: 2 additions & 2 deletions disperser/apiserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func TestDisperseBlobWithInvalidQuorum(t *testing.T) {
},
},
})
assert.Equal(t, err.Error(), "rpc error: code = InvalidArgument desc = invalid request: security_params must not contain duplicate quorum_id")
assert.Equal(t, err.Error(), "rpc error: code = InvalidArgument desc = failed to get blob from request: invalid request: security_params must not contain duplicate quorum_id")
}

func TestGetBlobStatus(t *testing.T) {
Expand Down Expand Up @@ -279,7 +279,7 @@ func TestDisperseBlobWithExceedSizeLimit(t *testing.T) {
},
})
assert.NotNil(t, err)
assert.Equal(t, err.Error(), "rpc error: code = InvalidArgument desc = blob size cannot exceed 2 MiB")
assert.Equal(t, err.Error(), "rpc error: code = InvalidArgument desc = failed to get blob from request: blob size cannot exceed 2 MiB")
}

func setup(m *testing.M) {
Expand Down
11 changes: 9 additions & 2 deletions operators/churner/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,18 @@ func (s *Server) validateChurnRequest(ctx context.Context, req *pb.ChurnRequest)
}

// TODO: ensure that all quorumIDs are valid
if len(req.QuorumIds) == 0 {
return errors.New("invalid quorumIds length")
if len(req.QuorumIds) == 0 || len(req.QuorumIds) > 256 {
return fmt.Errorf("invalid quorumIds length %d", len(req.QuorumIds))
}

seenQuorums := make(map[int]struct{})
for quorumID := range req.GetQuorumIds() {
// make sure there are no duplicate quorum IDs
if _, ok := seenQuorums[quorumID]; ok {
return errors.New("invalid request: security_params must not contain duplicate quorum_id")
}
seenQuorums[quorumID] = struct{}{}

if quorumID >= int(s.churner.QuorumCount) {
err := s.churner.UpdateQuorumCount(ctx)
if err != nil {
Expand Down

0 comments on commit 9ad5b45

Please sign in to comment.