
Commit

make chunk writer idempotent
ian-shim committed Nov 11, 2024
1 parent 2661594 commit 4c4f97a
Showing 8 changed files with 204 additions and 43 deletions.
64 changes: 58 additions & 6 deletions common/aws/mock/s3_client.go
@@ -2,22 +2,37 @@ package mock

import (
"context"
"errors"
"strings"

"github.com/Layr-Labs/eigenda/common/aws/s3"
)

type S3Client struct {
bucket map[string][]byte
Called map[string]int
}

var _ s3.Client = (*S3Client)(nil)

func NewS3Client() *S3Client {
return &S3Client{bucket: make(map[string][]byte)}
return &S3Client{
bucket: make(map[string][]byte),
Called: map[string]int{
"DownloadObject": 0,
"HeadObject": 0,
"UploadObject": 0,
"DeleteObject": 0,
"ListObjects": 0,
"CreateBucket": 0,
"FragmentedUploadObject": 0,
"FragmentedDownloadObject": 0,
},
}
}

func (s *S3Client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) {
s.Called["DownloadObject"]++
data, ok := s.bucket[key]
if !ok {
return []byte{}, s3.ErrObjectNotFound
@@ -26,6 +41,7 @@ func (s *S3Client) DownloadObject(ctx context.Context, bucket string, key string
}

func (s *S3Client) HeadObject(ctx context.Context, bucket string, key string) (*int64, error) {
s.Called["HeadObject"]++
data, ok := s.bucket[key]
if !ok {
return nil, s3.ErrObjectNotFound
@@ -35,17 +51,20 @@ func (s *S3Client) HeadObject(ctx context.Context, bucket string, key string) (*
}

func (s *S3Client) UploadObject(ctx context.Context, bucket string, key string, data []byte) error {
s.Called["UploadObject"]++
s.bucket[key] = data
return nil
}

func (s *S3Client) DeleteObject(ctx context.Context, bucket string, key string) error {
s.Called["DeleteObject"]++
delete(s.bucket, key)
return nil
}

func (s *S3Client) ListObjects(ctx context.Context, bucket string, prefix string) ([]s3.Object, error) {
objects := make([]s3.Object, 0, 5)
s.Called["ListObjects"]++
objects := make([]s3.Object, 0, 1000)
for k, v := range s.bucket {
if strings.HasPrefix(k, prefix) {
objects = append(objects, s3.Object{Key: k, Size: int64(len(v))})
@@ -55,6 +74,7 @@ func (s *S3Client) ListObjects(ctx context.Context, bucket string, prefix string
}

func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error {
s.Called["CreateBucket"]++
return nil
}

@@ -64,7 +84,14 @@ func (s *S3Client) FragmentedUploadObject(
key string,
data []byte,
fragmentSize int) error {
s.bucket[key] = data
s.Called["FragmentedUploadObject"]++
fragments, err := s3.BreakIntoFragments(key, data, fragmentSize)
if err != nil {
return err
}
for _, fragment := range fragments {
s.bucket[fragment.FragmentKey] = fragment.Data
}
return nil
}

@@ -74,9 +101,34 @@ func (s *S3Client) FragmentedDownloadObject(
key string,
fileSize int,
fragmentSize int) ([]byte, error) {
data, ok := s.bucket[key]
if !ok {
return []byte{}, s3.ErrObjectNotFound
s.Called["FragmentedDownloadObject"]++
if fileSize <= 0 {
return nil, errors.New("fileSize must be greater than 0")
}
if fragmentSize <= 0 {
return nil, errors.New("fragmentSize must be greater than 0")
}

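// Number of fragments to read back: ceil(fileSize / fragmentSize).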
count := 0
if fileSize < fragmentSize {
count = 1
} else if fileSize%fragmentSize == 0 {
count = fileSize / fragmentSize
} else {
count = fileSize/fragmentSize + 1
}
fragmentKeys, err := s3.GetFragmentKeys(key, count)
if err != nil {
return nil, err
}

data := make([]byte, 0, fileSize)
for _, fragmentKey := range fragmentKeys {
fragmentData, ok := s.bucket[fragmentKey]
if !ok {
return nil, s3.ErrObjectNotFound
}
data = append(data, fragmentData...)
}
return data, nil
}
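
For reference, here is a minimal sketch of how a test could exercise the updated mock: upload the same key twice with FragmentedUploadObject, read it back with FragmentedDownloadObject, and inspect the Called counters. The test name, bucket, key, and import path are illustrative assumptions, not part of this change.

package mock_test

import (
	"context"
	"testing"

	"github.com/Layr-Labs/eigenda/common/aws/mock"
	"github.com/stretchr/testify/assert"
)

func TestFragmentedUploadIsIdempotent(t *testing.T) {
	client := mock.NewS3Client()
	data := []byte("hello fragmented world") // 22 bytes
	fragmentSize := 8                        // stored as 3 fragments

	// Uploading the same key twice should leave the stored fragments unchanged.
	for i := 0; i < 2; i++ {
		err := client.FragmentedUploadObject(context.Background(), "test-bucket", "blob-key", data, fragmentSize)
		assert.NoError(t, err)
	}

	// The download reassembles the fragments into the original payload.
	downloaded, err := client.FragmentedDownloadObject(context.Background(), "test-bucket", "blob-key", len(data), fragmentSize)
	assert.NoError(t, err)
	assert.Equal(t, data, downloaded)

	// The mock records how many times each method was invoked.
	assert.Equal(t, 2, client.Called["FragmentedUploadObject"])
	assert.Equal(t, 1, client.Called["FragmentedDownloadObject"])
}
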
12 changes: 10 additions & 2 deletions common/aws/s3/client.go
@@ -14,6 +14,7 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/feature/s3/manager"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"golang.org/x/sync/errgroup"
)

@@ -130,6 +131,10 @@ func (s *client) HeadObject(ctx context.Context, bucket string, key string) (*in
Key: aws.String(key),
})
if err != nil {
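// Translate the SDK's typed NotFound error into the package-level ErrObjectNotFound sentinel.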
var notFound *types.NotFound
if ok := errors.As(err, &notFound); ok {
return nil, ErrObjectNotFound
}
return nil, err
}

@@ -209,7 +214,7 @@ func (s *client) FragmentedUploadObject(
data []byte,
fragmentSize int) error {

fragments, err := breakIntoFragments(key, data, fragmentSize)
fragments, err := BreakIntoFragments(key, data, fragmentSize)
if err != nil {
return err
}
@@ -259,12 +264,15 @@ func (s *client) FragmentedDownloadObject(
key string,
fileSize int,
fragmentSize int) ([]byte, error) {
if fileSize <= 0 {
return nil, errors.New("fileSize must be greater than 0")
}

if fragmentSize <= 0 {
return nil, errors.New("fragmentSize must be greater than 0")
}

fragmentKeys, err := getFragmentKeys(key, getFragmentCount(fileSize, fragmentSize))
fragmentKeys, err := GetFragmentKeys(key, getFragmentCount(fileSize, fragmentSize))
if err != nil {
return nil, err
}
8 changes: 4 additions & 4 deletions common/aws/s3/fragment.go
@@ -45,8 +45,8 @@ type Fragment struct {
Index int
}

// breakIntoFragments breaks a file into fragments of the given size.
func breakIntoFragments(fileKey string, data []byte, fragmentSize int) ([]*Fragment, error) {
// BreakIntoFragments breaks a file into fragments of the given size.
func BreakIntoFragments(fileKey string, data []byte, fragmentSize int) ([]*Fragment, error) {
fragmentCount := getFragmentCount(len(data), fragmentSize)
fragments := make([]*Fragment, fragmentCount)
for i := 0; i < fragmentCount; i++ {
@@ -69,8 +69,8 @@ func breakIntoFragments(fileKey string, data []byte, fragmentSize int) ([]*Fragm
return fragments, nil
}

// getFragmentKeys returns the keys for all fragments of a file.
func getFragmentKeys(fileKey string, fragmentCount int) ([]string, error) {
// GetFragmentKeys returns the keys for all fragments of a file.
func GetFragmentKeys(fileKey string, fragmentCount int) ([]string, error) {
keys := make([]string, fragmentCount)
for i := 0; i < fragmentCount; i++ {
fragmentKey, err := getFragmentKey(fileKey, fragmentCount, i)
23 changes: 12 additions & 11 deletions common/aws/s3/fragment_test.go
@@ -2,11 +2,12 @@ package s3

import (
"fmt"
tu "github.com/Layr-Labs/eigenda/common/testutils"
"github.com/stretchr/testify/assert"
"math/rand"
"strings"
"testing"

tu "github.com/Layr-Labs/eigenda/common/testutils"
"github.com/stretchr/testify/assert"
)

func TestGetFragmentCount(t *testing.T) {
@@ -115,7 +116,7 @@ func TestKeyPostfix(t *testing.T) {
func TestExampleInGodoc(t *testing.T) {
fileKey := "abc123"
fragmentCount := 3
fragmentKeys, err := getFragmentKeys(fileKey, fragmentCount)
fragmentKeys, err := GetFragmentKeys(fileKey, fragmentCount)
assert.NoError(t, err)
assert.Equal(t, 3, len(fragmentKeys))
assert.Equal(t, "abc123-0", fragmentKeys[0])
@@ -129,7 +130,7 @@ func TestGetFragmentKeys(t *testing.T) {
fileKey := tu.RandomString(10)
fragmentCount := rand.Intn(10) + 10

fragmentKeys, err := getFragmentKeys(fileKey, fragmentCount)
fragmentKeys, err := GetFragmentKeys(fileKey, fragmentCount)
assert.NoError(t, err)
assert.Equal(t, fragmentCount, len(fragmentKeys))

@@ -159,7 +160,7 @@ func TestGetFragments(t *testing.T) {
data := tu.RandomBytes(1000)
fragmentSize := rand.Intn(100) + 100

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)
assert.Equal(t, getFragmentCount(len(data), fragmentSize), len(fragments))

@@ -190,7 +191,7 @@ func TestGetFragmentsSmallFile(t *testing.T) {
data := tu.RandomBytes(10)
fragmentSize := rand.Intn(100) + 100

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)
assert.Equal(t, 1, len(fragments))

@@ -208,7 +209,7 @@ func TestGetFragmentsExactlyOnePerfectlySizedFile(t *testing.T) {
fragmentSize := rand.Intn(100) + 100
data := tu.RandomBytes(fragmentSize)

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)
assert.Equal(t, 1, len(fragments))

@@ -226,7 +227,7 @@ func TestRecombineFragments(t *testing.T) {
data := tu.RandomBytes(1000)
fragmentSize := rand.Intn(100) + 100

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)
recombinedData, err := recombineFragments(fragments)
assert.NoError(t, err)
@@ -250,7 +251,7 @@ func TestRecombineFragmentsSmallFile(t *testing.T) {
data := tu.RandomBytes(10)
fragmentSize := rand.Intn(100) + 100

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)
assert.Equal(t, 1, len(fragments))
recombinedData, err := recombineFragments(fragments)
@@ -265,7 +266,7 @@ func TestMissingFragment(t *testing.T) {
data := tu.RandomBytes(1000)
fragmentSize := rand.Intn(100) + 100

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)

fragmentIndexToSkip := rand.Intn(len(fragments))
@@ -282,7 +283,7 @@ func TestMissingFinalFragment(t *testing.T) {
data := tu.RandomBytes(1000)
fragmentSize := rand.Intn(100) + 100

fragments, err := breakIntoFragments(fileKey, data, fragmentSize)
fragments, err := BreakIntoFragments(fileKey, data, fragmentSize)
assert.NoError(t, err)
fragments = fragments[:len(fragments)-1]

26 changes: 14 additions & 12 deletions common/aws/test/client_test.go
@@ -2,6 +2,7 @@ package test

import (
"context"
"math"
"math/rand"
"os"
"testing"
@@ -118,9 +119,8 @@ func RandomOperationsTest(t *testing.T, client s3.Client) {
expectedData := make(map[string][]byte)

fragmentSize := rand.Intn(1000) + 1000
prefix := "test-"
for i := 0; i < numberToWrite; i++ {
key := prefix + tu.RandomString(10)
key := tu.RandomString(10)
fragmentMultiple := rand.Float64() * 10
dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1
data := tu.RandomBytes(dataSize)
@@ -134,19 +134,21 @@
data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize)
assert.NoError(t, err)
assert.Equal(t, expected, data)
}

// List the objects
objects, err := client.ListObjects(context.Background(), bucket, prefix)
assert.NoError(t, err)
assert.Len(t, objects, numberToWrite)
for _, object := range objects {
assert.Contains(t, expectedData, object.Key)
assert.Equal(t, int64(len(expectedData[object.Key])), object.Size)
// List the objects
objects, err := client.ListObjects(context.Background(), bucket, key)
assert.NoError(t, err)
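// Each object is stored as ceil(fileSize / fragmentSize) fragments whose sizes sum to the original length.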
numFragments := math.Ceil(float64(len(expected)) / float64(fragmentSize))
assert.Len(t, objects, int(numFragments))
totalSize := int64(0)
for _, object := range objects {
totalSize += object.Size
}
assert.Equal(t, int64(len(expected)), totalSize)
}

// Attempt to list non-existent objects
objects, err = client.ListObjects(context.Background(), bucket, "nonexistent")
objects, err := client.ListObjects(context.Background(), bucket, "nonexistent")
assert.NoError(t, err)
assert.Len(t, objects, 0)
}
@@ -207,7 +209,7 @@ func TestHeadObject(t *testing.T) {
assert.Equal(t, int64(4), *size)

size, err = client.HeadObject(context.Background(), bucket, "nonexistent")
assert.Error(t, err)
assert.ErrorIs(t, err, s3.ErrObjectNotFound)
assert.Nil(t, size)

err = builder.finish()
2 changes: 1 addition & 1 deletion disperser/encoder/setup_test.go
@@ -81,7 +81,7 @@ func setup() {
chunkStoreWriter = chunkstore.NewChunkWriter(logger, s3Client, s3BucketName, 512*1024)

// Initialize chunk store reader
chunkStoreReader = chunkstore.NewChunkReader(logger, nil, s3Client, s3BucketName, []uint32{})
chunkStoreReader = chunkstore.NewChunkReader(logger, s3Client, s3BucketName, []uint32{})

var X1, Y1 fp.Element
X1 = *X1.SetBigInt(big.NewInt(1))