diff --git a/common/aws/cli.go b/common/aws/cli.go index 9a4a51b744..5a6d11503b 100644 --- a/common/aws/cli.go +++ b/common/aws/cli.go @@ -3,48 +3,20 @@ package aws import ( "github.com/Layr-Labs/eigenda/common" "github.com/urfave/cli" - "time" ) var ( - RegionFlagName = "aws.region" - AccessKeyIdFlagName = "aws.access-key-id" - SecretAccessKeyFlagName = "aws.secret-access-key" - EndpointURLFlagName = "aws.endpoint-url" - FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars" - FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor" - FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant" - FragmentReadTimeoutFlagName = "aws.fragment-read-timeout" - FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout" + RegionFlagName = "aws.region" + AccessKeyIdFlagName = "aws.access-key-id" + SecretAccessKeyFlagName = "aws.secret-access-key" + EndpointURLFlagName = "aws.endpoint-url" ) type ClientConfig struct { - // Region is the region to use when interacting with S3. Default is "us-east-2". - Region string - // AccessKey to use when interacting with S3. - AccessKey string - // SecretAccessKey to use when interacting with S3. + Region string + AccessKey string SecretAccessKey string - // EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used. - EndpointURL string - - // FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files. - // A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3. - FragmentPrefixChars int - // FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files. - // A non-zero value for this parameter adds a number of workers equal to the number of cores times this value. - // Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the - // workers will be blocked on I/O most of the time. - FragmentParallelismFactor int - // FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files. - // A non-zero value for this parameter adds a constant number of workers. Default is 0. - FragmentParallelismConstant int - // FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read. - // Default is 30 seconds. - FragmentReadTimeout time.Duration - // FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write. - // Default is 30 seconds. - FragmentWriteTimeout time.Duration + EndpointURL string } func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { @@ -76,66 +48,14 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag { Value: "", EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"), }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "The number of characters of the key to use as the prefix for fragmented files", - Required: false, - Value: 3, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName), - Usage: "Add this many threads times the number of cores to the worker pool", - Required: false, - Value: 8, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"), - }, - cli.IntFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName), - Usage: "Add this many threads to the worker pool", - Required: false, - Value: 0, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented read", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"), - }, - cli.DurationFlag{ - Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName), - Usage: "The maximum time to wait for a single fragmented write", - Required: false, - Value: 30 * time.Second, - EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"), - }, } } func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig { return ClientConfig{ - Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), - AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), - SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), - EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), - FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)), - FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)), - FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)), - FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)), - FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)), - } -} - -// DefaultClientConfig returns a new ClientConfig with default values. -func DefaultClientConfig() *ClientConfig { - return &ClientConfig{ - Region: "us-east-2", - FragmentPrefixChars: 3, - FragmentParallelismFactor: 8, - FragmentParallelismConstant: 0, - FragmentReadTimeout: 30 * time.Second, - FragmentWriteTimeout: 30 * time.Second, + Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)), + AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)), + SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)), + EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)), } } diff --git a/common/aws/s3/client.go b/common/aws/s3/client.go index 8a88318117..231d546ae6 100644 --- a/common/aws/s3/client.go +++ b/common/aws/s3/client.go @@ -4,8 +4,6 @@ import ( "bytes" "context" "errors" - "github.com/gammazero/workerpool" - "runtime" "sync" commonaws "github.com/Layr-Labs/eigenda/common/aws" @@ -29,9 +27,7 @@ type Object struct { } type client struct { - cfg *commonaws.ClientConfig s3Client *s3.Client - pool *workerpool.WorkerPool logger logging.Logger } @@ -40,19 +36,18 @@ var _ Client = (*client)(nil) func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) { var err error once.Do(func() { - customResolver := aws.EndpointResolverWithOptionsFunc( - func(service, region string, options ...interface{}) (aws.Endpoint, error) { - if cfg.EndpointURL != "" { - return aws.Endpoint{ - PartitionID: "aws", - URL: cfg.EndpointURL, - SigningRegion: cfg.Region, - }, nil - } - - // returning EndpointNotFoundError will allow the service to fallback to its default resolution - return aws.Endpoint{}, &aws.EndpointNotFoundError{} - }) + customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) { + if cfg.EndpointURL != "" { + return aws.Endpoint{ + PartitionID: "aws", + URL: cfg.EndpointURL, + SigningRegion: cfg.Region, + }, nil + } + + // returning EndpointNotFoundError will allow the service to fallback to its default resolution + return aws.Endpoint{}, &aws.EndpointNotFoundError{} + }) options := [](func(*config.LoadOptions) error){ config.WithRegion(cfg.Region), @@ -61,9 +56,7 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L } // If access key and secret access key are not provided, use the default credential provider if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 { - options = append(options, - config.WithCredentialsProvider( - credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) + options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, ""))) } awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...) @@ -71,34 +64,25 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L err = errCfg return } - s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) { o.UsePathStyle = true }) - - workers := 0 - if cfg.FragmentParallelismConstant > 0 { - workers = cfg.FragmentParallelismConstant - } - if cfg.FragmentParallelismFactor > 0 { - workers = cfg.FragmentParallelismFactor * runtime.NumCPU() - } - - if workers == 0 { - workers = 1 - } - pool := workerpool.New(workers) - - ref = &client{ - cfg: &cfg, - s3Client: s3Client, - pool: pool, - logger: logger.With("component", "S3Client"), - } + ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")} }) return ref, err } +func (s *client) CreateBucket(ctx context.Context, bucket string) error { + _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: aws.String(bucket), + }) + if err != nil { + return err + } + + return nil +} + func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { var partMiBs int64 = 10 downloader := manager.NewDownloader(s.s3Client, func(d *manager.Downloader) { @@ -175,162 +159,3 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string) } return objects, nil } - -func (s *client) CreateBucket(ctx context.Context, bucket string) error { - _, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{ - Bucket: aws.String(bucket), - }) - if err != nil { - return err - } - - return nil -} - -func (s *client) FragmentedUploadObject( - ctx context.Context, - bucket string, - key string, - data []byte, - fragmentSize int) error { - - fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize) - if err != nil { - return err - } - resultChannel := make(chan error, len(fragments)) - - ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) - defer cancel() - - for _, fragment := range fragments { - fragmentCapture := fragment - s.pool.Submit(func() { - s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket) - }) - } - - for range fragments { - err := <-resultChannel - if err != nil { - return err - } - } - return ctx.Err() - -} - -// fragmentedWriteTask writes a single file to S3. -func (s *client) fragmentedWriteTask( - ctx context.Context, - resultChannel chan error, - fragment *Fragment, - bucket string) { - - _, err := s.s3Client.PutObject(ctx, - &s3.PutObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(fragment.FragmentKey), - Body: bytes.NewReader(fragment.Data), - }) - - resultChannel <- err -} - -func (s *client) FragmentedDownloadObject( - ctx context.Context, - bucket string, - key string, - fileSize int, - fragmentSize int) ([]byte, error) { - - if fragmentSize <= 0 { - return nil, errors.New("fragmentSize must be greater than 0") - } - - fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize)) - if err != nil { - return nil, err - } - resultChannel := make(chan *readResult, len(fragmentKeys)) - - ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout) - defer cancel() - - for i, fragmentKey := range fragmentKeys { - boundFragmentKey := fragmentKey - boundI := i - s.pool.Submit(func() { - s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI) - }) - } - - fragments := make([]*Fragment, len(fragmentKeys)) - for i := 0; i < len(fragmentKeys); i++ { - result := <-resultChannel - if result.err != nil { - return nil, result.err - } - fragments[result.fragment.Index] = result.fragment - } - - if ctx.Err() != nil { - return nil, ctx.Err() - } - - return RecombineFragments(fragments) - -} - -// readResult is the result of a read task. -type readResult struct { - fragment *Fragment - err error -} - -// readTask reads a single file from S3. -func (s *client) readTask( - ctx context.Context, - resultChannel chan *readResult, - bucket string, - key string, - index int) { - - result := &readResult{} - defer func() { - resultChannel <- result - }() - - ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(bucket), - Key: aws.String(key), - }) - - if err != nil { - result.err = err - return - } - - data := make([]byte, *ret.ContentLength) - bytesRead := 0 - - for bytesRead < len(data) && ctx.Err() == nil { - count, err := ret.Body.Read(data[bytesRead:]) - if err != nil && err.Error() != "EOF" { - result.err = err - return - } - bytesRead += count - } - - result.fragment = &Fragment{ - FragmentKey: key, - Data: data, - Index: index, - } - - err = ret.Body.Close() - if err != nil { - result.err = err - } -} diff --git a/common/aws/s3/fragment.go b/common/aws/s3/fragment.go deleted file mode 100644 index 6f978fbdc6..0000000000 --- a/common/aws/s3/fragment.go +++ /dev/null @@ -1,128 +0,0 @@ -package s3 - -import ( - "fmt" - "sort" - "strings" -) - -// GetFragmentCount returns the number of fragments that a file of the given size will be broken into. -func GetFragmentCount(fileSize int, fragmentSize int) int { - if fileSize < fragmentSize { - return 1 - } else if fileSize%fragmentSize == 0 { - return fileSize / fragmentSize - } else { - return fileSize/fragmentSize + 1 - } -} - -// GetFragmentKey returns the key for the fragment at the given index. -// -// Fragment keys take the form of "prefix/body-index[f]". The prefix is the first prefixLength characters -// of the file key. The body is the file key. The index is the index of the fragment. The character "f" is appended -// to the key of the last fragment in the series. -// -// Example: fileKey="abc123", prefixLength=2, fragmentCount=3 -// The keys will be "ab/abc123-0", "ab/abc123-1", "ab/abc123-2f" -func GetFragmentKey(fileKey string, prefixLength int, fragmentCount int, index int) (string, error) { - var prefix string - if prefixLength > len(fileKey) { - prefix = fileKey - } else { - prefix = fileKey[:prefixLength] - } - - postfix := "" - if fragmentCount-1 == index { - postfix = "f" - } - - if index >= fragmentCount { - return "", fmt.Errorf("index %d is too high for fragment count %d", index, fragmentCount) - } - - return fmt.Sprintf("%s/%s-%d%s", prefix, fileKey, index, postfix), nil -} - -// Fragment is a subset of a file. -type Fragment struct { - FragmentKey string - Data []byte - Index int -} - -// BreakIntoFragments breaks a file into fragments of the given size. -func BreakIntoFragments(fileKey string, data []byte, prefixLength int, fragmentSize int) ([]*Fragment, error) { - fragmentCount := GetFragmentCount(len(data), fragmentSize) - fragments := make([]*Fragment, fragmentCount) - for i := 0; i < fragmentCount; i++ { - start := i * fragmentSize - end := start + fragmentSize - if end > len(data) { - end = len(data) - } - - fragmentKey, err := GetFragmentKey(fileKey, prefixLength, fragmentCount, i) - if err != nil { - return nil, err - } - fragments[i] = &Fragment{ - FragmentKey: fragmentKey, - Data: data[start:end], - Index: i, - } - } - return fragments, nil -} - -// GetFragmentKeys returns the keys for all fragments of a file. -func GetFragmentKeys(fileKey string, prefixLength int, fragmentCount int) ([]string, error) { - keys := make([]string, fragmentCount) - for i := 0; i < fragmentCount; i++ { - fragmentKey, err := GetFragmentKey(fileKey, prefixLength, fragmentCount, i) - if err != nil { - return nil, err - } - keys[i] = fragmentKey - } - return keys, nil -} - -// RecombineFragments recombines fragments into a single file. -// Returns an error if any fragments are missing. -func RecombineFragments(fragments []*Fragment) ([]byte, error) { - - if len(fragments) == 0 { - return nil, fmt.Errorf("no fragments") - } - - // Sort the fragments by index - sort.Slice(fragments, func(i, j int) bool { - return fragments[i].Index < fragments[j].Index - }) - - // Make sure there aren't any gaps in the fragment indices - dataSize := 0 - for i, fragment := range fragments { - if fragment.Index != i { - return nil, fmt.Errorf("missing fragment with index %d", i) - } - dataSize += len(fragment.Data) - } - - // Make sure we have the last fragment - if !strings.HasSuffix(fragments[len(fragments)-1].FragmentKey, "f") { - return nil, fmt.Errorf("missing final fragment") - } - - fragmentSize := len(fragments[0].Data) - - // Concatenate the data - result := make([]byte, dataSize) - for _, fragment := range fragments { - copy(result[fragment.Index*fragmentSize:], fragment.Data) - } - - return result, nil -} diff --git a/common/aws/s3/s3.go b/common/aws/s3/s3.go index 74089099a9..475f68c941 100644 --- a/common/aws/s3/s3.go +++ b/common/aws/s3/s3.go @@ -2,47 +2,10 @@ package s3 import "context" -// Client encapsulates the functionality of an S3 client. type Client interface { - - // DownloadObject downloads an object from S3. + CreateBucket(ctx context.Context, bucket string) error DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) - - // UploadObject uploads an object to S3. UploadObject(ctx context.Context, bucket string, key string, data []byte) error - - // DeleteObject deletes an object from S3. DeleteObject(ctx context.Context, bucket string, key string) error - - // ListObjects lists all objects in a bucket with the given prefix. Note that this method may return - // file fragments if the bucket contains files uploaded via FragmentedUploadObject. ListObjects(ctx context.Context, bucket string, prefix string) ([]Object, error) - - // CreateBucket creates a bucket in S3. - CreateBucket(ctx context.Context, bucket string) error - - // FragmentedUploadObject uploads a file to S3. The fragmentSize parameter specifies the maximum size of each - // file uploaded to S3. If the file is larger than fragmentSize then it will be broken into - // smaller parts and uploaded in parallel. The file will be reassembled on download. - // - // Note: if a file is uploaded with this method, only the FragmentedDownloadObject method should be used to - // download the file. It is not advised to use DeleteObject on files uploaded with this method (if such - // functionality is required, a new method to do so should be added to this interface). - FragmentedUploadObject( - ctx context.Context, - bucket string, - key string, - data []byte, - fragmentSize int) error - - // FragmentedDownloadObject downloads a file from S3, as written by Upload. The fileSize (in bytes) and fragmentSize - // must be the same as the values used in the FragmentedUploadObject call. - // - // Note: this method can only be used to download files that were uploaded with the FragmentedUploadObject method. - FragmentedDownloadObject( - ctx context.Context, - bucket string, - key string, - fileSize int, - fragmentSize int) ([]byte, error) } diff --git a/common/aws/test/client_test.go b/common/aws/test/client_test.go deleted file mode 100644 index 0f5bc4087a..0000000000 --- a/common/aws/test/client_test.go +++ /dev/null @@ -1,176 +0,0 @@ -package test - -import ( - "context" - "github.com/Layr-Labs/eigenda/common" - "github.com/Layr-Labs/eigenda/common/aws" - "github.com/Layr-Labs/eigenda/common/aws/s3" - "github.com/Layr-Labs/eigenda/common/mock" - tu "github.com/Layr-Labs/eigenda/common/testutils" - "github.com/Layr-Labs/eigenda/inabox/deploy" - "github.com/ory/dockertest/v3" - "github.com/stretchr/testify/assert" - "math/rand" - "os" - "testing" -) - -var ( - dockertestPool *dockertest.Pool - dockertestResource *dockertest.Resource -) - -const ( - localstackPort = "4570" - localstackHost = "http://0.0.0.0:4570" - bucket = "eigen-test" -) - -type clientBuilder struct { - // This method is called at the beginning of the test. - start func() error - // This method is called to build a new client. - build func() (s3.Client, error) - // This method is called at the end of the test when all operations are done. - finish func() error -} - -var clientBuilders = []*clientBuilder{ - { - start: func() error { - return nil - }, - build: func() (s3.Client, error) { - return mock.NewS3Client(), nil - }, - finish: func() error { - return nil - }, - }, - { - start: func() error { - return setupLocalstack() - }, - build: func() (s3.Client, error) { - - logger, err := common.NewLogger(common.DefaultLoggerConfig()) - if err != nil { - return nil, err - } - - config := aws.DefaultClientConfig() - config.EndpointURL = localstackHost - config.Region = "us-east-1" - - err = os.Setenv("AWS_ACCESS_KEY_ID", "localstack") - if err != nil { - return nil, err - } - err = os.Setenv("AWS_SECRET_ACCESS_KEY", "localstack") - if err != nil { - return nil, err - } - - client, err := s3.NewClient(context.Background(), *config, logger) - if err != nil { - return nil, err - } - - err = client.CreateBucket(context.Background(), bucket) - if err != nil { - return nil, err - } - - return client, nil - }, - finish: func() error { - teardownLocalstack() - return nil - }, - }, -} - -func setupLocalstack() error { - deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") - - if deployLocalStack { - var err error - dockertestPool, dockertestResource, err = deploy.StartDockertestWithLocalstackContainer(localstackPort) - if err != nil && err.Error() == "container already exists" { - teardownLocalstack() - return err - } - } - return nil -} - -func teardownLocalstack() { - deployLocalStack := !(os.Getenv("DEPLOY_LOCALSTACK") == "false") - - if deployLocalStack { - deploy.PurgeDockertestResources(dockertestPool, dockertestResource) - } -} - -func RandomOperationsTest(t *testing.T, client s3.Client) { - numberToWrite := 100 - expectedData := make(map[string][]byte) - - fragmentSize := rand.Intn(1000) + 1000 - - for i := 0; i < numberToWrite; i++ { - key := tu.RandomString(10) - fragmentMultiple := rand.Float64() * 10 - dataSize := int(fragmentMultiple*float64(fragmentSize)) + 1 - data := tu.RandomBytes(dataSize) - expectedData[key] = data - - err := client.FragmentedUploadObject(context.Background(), bucket, key, data, fragmentSize) - assert.NoError(t, err) - } - - // Read back the data - for key, expected := range expectedData { - data, err := client.FragmentedDownloadObject(context.Background(), bucket, key, len(expected), fragmentSize) - assert.NoError(t, err) - assert.Equal(t, expected, data) - } -} - -func TestRandomOperations(t *testing.T) { - tu.InitializeRandom() - for _, builder := range clientBuilders { - err := builder.start() - assert.NoError(t, err) - - client, err := builder.build() - assert.NoError(t, err) - RandomOperationsTest(t, client) - - err = builder.finish() - assert.NoError(t, err) - } -} - -func ReadNonExistentValueTest(t *testing.T, client s3.Client) { - _, err := client.FragmentedDownloadObject(context.Background(), bucket, "nonexistent", 1000, 1000) - assert.Error(t, err) - randomKey := tu.RandomString(10) - _, err = client.FragmentedDownloadObject(context.Background(), bucket, randomKey, 0, 0) - assert.Error(t, err) -} - -func TestReadNonExistentValue(t *testing.T) { - tu.InitializeRandom() - for _, builder := range clientBuilders { - err := builder.start() - assert.NoError(t, err) - - client, err := builder.build() - assert.NoError(t, err) - ReadNonExistentValueTest(t, client) - - err = builder.finish() - assert.NoError(t, err) - } -} diff --git a/common/aws/test/fragment_test.go b/common/aws/test/fragment_test.go deleted file mode 100644 index fc5a257731..0000000000 --- a/common/aws/test/fragment_test.go +++ /dev/null @@ -1,330 +0,0 @@ -package test - -import ( - "fmt" - "github.com/Layr-Labs/eigenda/common/aws/s3" - tu "github.com/Layr-Labs/eigenda/common/testutils" - "github.com/stretchr/testify/assert" - "math/rand" - "strings" - "testing" -) - -func TestGetFragmentCount(t *testing.T) { - tu.InitializeRandom() - - // Test a file smaller than a fragment - fileSize := rand.Intn(100) + 100 - fragmentSize := fileSize * 2 - fragmentCount := s3.GetFragmentCount(fileSize, fragmentSize) - assert.Equal(t, 1, fragmentCount) - - // Test a file that can fit in a single fragment - fileSize = rand.Intn(100) + 100 - fragmentSize = fileSize - fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) - assert.Equal(t, 1, fragmentCount) - - // Test a file that is one byte larger than a fragment - fileSize = rand.Intn(100) + 100 - fragmentSize = fileSize - 1 - fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) - assert.Equal(t, 2, fragmentCount) - - // Test a file that is one less than a multiple of the fragment size - fragmentSize = rand.Intn(100) + 100 - expectedFragmentCount := rand.Intn(10) + 1 - fileSize = fragmentSize*expectedFragmentCount - 1 - fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) - assert.Equal(t, expectedFragmentCount, fragmentCount) - - // Test a file that is a multiple of the fragment size - fragmentSize = rand.Intn(100) + 100 - expectedFragmentCount = rand.Intn(10) + 1 - fileSize = fragmentSize * expectedFragmentCount - fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) - assert.Equal(t, expectedFragmentCount, fragmentCount) - - // Test a file that is one more than a multiple of the fragment size - fragmentSize = rand.Intn(100) + 100 - expectedFragmentCount = rand.Intn(10) + 2 - fileSize = fragmentSize*(expectedFragmentCount-1) + 1 - fragmentCount = s3.GetFragmentCount(fileSize, fragmentSize) - assert.Equal(t, expectedFragmentCount, fragmentCount) -} - -// Fragment keys take the form of "prefix/body-index[f]". Verify the prefix part of the key. -func TestPrefix(t *testing.T) { - tu.InitializeRandom() - - keyLength := rand.Intn(10) + 10 - key := tu.RandomString(keyLength) - - for i := 0; i < keyLength*2; i++ { - fragmentCount := rand.Intn(10) + 10 - fragmentIndex := rand.Intn(fragmentCount) - fragmentKey, err := s3.GetFragmentKey(key, i, fragmentCount, fragmentIndex) - assert.NoError(t, err) - - parts := strings.Split(fragmentKey, "/") - assert.Equal(t, 2, len(parts)) - prefix := parts[0] - - if i >= keyLength { - assert.Equal(t, key, prefix) - } else { - assert.Equal(t, key[:i], prefix) - } - } -} - -// Fragment keys take the form of "prefix/body-index[f]". Verify the body part of the key. -func TestKeyBody(t *testing.T) { - tu.InitializeRandom() - - for i := 0; i < 10; i++ { - keyLength := rand.Intn(10) + 10 - key := tu.RandomString(keyLength) - fragmentCount := rand.Intn(10) + 10 - fragmentIndex := rand.Intn(fragmentCount) - fragmentKey, err := s3.GetFragmentKey(key, rand.Intn(10), fragmentCount, fragmentIndex) - assert.NoError(t, err) - - parts := strings.Split(fragmentKey, "/") - assert.Equal(t, 2, len(parts)) - parts = strings.Split(parts[1], "-") - assert.Equal(t, 2, len(parts)) - body := parts[0] - - assert.Equal(t, key, body) - } -} - -// Fragment keys take the form of "prefix/body-index[f]". Verify the index part of the key. -func TestKeyIndex(t *testing.T) { - tu.InitializeRandom() - - for i := 0; i < 10; i++ { - fragmentCount := rand.Intn(10) + 10 - index := rand.Intn(fragmentCount) - fragmentKey, err := s3.GetFragmentKey(tu.RandomString(10), rand.Intn(10), fragmentCount, index) - assert.NoError(t, err) - - parts := strings.Split(fragmentKey, "/") - assert.Equal(t, 2, len(parts)) - parts = strings.Split(parts[1], "-") - assert.Equal(t, 2, len(parts)) - indexStr := parts[1] - assert.True(t, strings.HasPrefix(indexStr, fmt.Sprintf("%d", index))) - } -} - -// Fragment keys take the form of "prefix/body-index[f]". -// Verify the postfix part of the key, which should be "f" for the last fragment. -func TestKeyPostfix(t *testing.T) { - tu.InitializeRandom() - - segmentCount := rand.Intn(10) + 10 - - for i := 0; i < segmentCount; i++ { - fragmentKey, err := s3.GetFragmentKey(tu.RandomString(10), rand.Intn(10), segmentCount, i) - assert.NoError(t, err) - - if i == segmentCount-1 { - assert.True(t, strings.HasSuffix(fragmentKey, "f")) - } else { - assert.False(t, strings.HasSuffix(fragmentKey, "f")) - } - } -} - -func TestExampleInGodoc(t *testing.T) { - fileKey := "abc123" - prefixLength := 2 - fragmentCount := 3 - fragmentKeys, err := s3.GetFragmentKeys(fileKey, prefixLength, fragmentCount) - assert.NoError(t, err) - assert.Equal(t, 3, len(fragmentKeys)) - assert.Equal(t, "ab/abc123-0", fragmentKeys[0]) - assert.Equal(t, "ab/abc123-1", fragmentKeys[1]) - assert.Equal(t, "ab/abc123-2f", fragmentKeys[2]) -} - -func TestGetFragmentKeys(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - prefixLength := rand.Intn(3) + 1 - fragmentCount := rand.Intn(10) + 10 - - fragmentKeys, err := s3.GetFragmentKeys(fileKey, prefixLength, fragmentCount) - assert.NoError(t, err) - assert.Equal(t, fragmentCount, len(fragmentKeys)) - - for i := 0; i < fragmentCount; i++ { - expectedKey, err := s3.GetFragmentKey(fileKey, prefixLength, fragmentCount, i) - assert.NoError(t, err) - assert.Equal(t, expectedKey, fragmentKeys[i]) - - parts := strings.Split(fragmentKeys[i], "/") - assert.Equal(t, 2, len(parts)) - parsedPrefix := parts[0] - assert.Equal(t, fileKey[:prefixLength], parsedPrefix) - parts = strings.Split(parts[1], "-") - assert.Equal(t, 2, len(parts)) - parsedKey := parts[0] - assert.Equal(t, fileKey, parsedKey) - index := parts[1] - - if i == fragmentCount-1 { - assert.Equal(t, fmt.Sprintf("%d", i)+"f", index) - } else { - assert.Equal(t, fmt.Sprintf("%d", i), index) - } - } -} - -func TestGetFragments(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - data := tu.RandomBytes(1000) - prefixLength := rand.Intn(3) + 1 - fragmentSize := rand.Intn(100) + 100 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - assert.Equal(t, s3.GetFragmentCount(len(data), fragmentSize), len(fragments)) - - totalSize := 0 - - for i, fragment := range fragments { - fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, len(fragments), i) - assert.NoError(t, err) - assert.Equal(t, fragmentKey, fragment.FragmentKey) - - start := i * fragmentSize - end := start + fragmentSize - if end > len(data) { - end = len(data) - } - assert.Equal(t, data[start:end], fragment.Data) - assert.Equal(t, i, fragment.Index) - totalSize += len(fragment.Data) - } - - assert.Equal(t, len(data), totalSize) -} - -func TestGetFragmentsSmallFile(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - data := tu.RandomBytes(10) - prefixLength := rand.Intn(3) + 1 - fragmentSize := rand.Intn(100) + 100 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - assert.Equal(t, 1, len(fragments)) - - fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, 1, 0) - assert.NoError(t, err) - assert.Equal(t, fragmentKey, fragments[0].FragmentKey) - assert.Equal(t, data, fragments[0].Data) - assert.Equal(t, 0, fragments[0].Index) -} - -func TestGetFragmentsExactlyOnePerfectlySizedFile(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - fragmentSize := rand.Intn(100) + 100 - data := tu.RandomBytes(fragmentSize) - prefixLength := rand.Intn(3) + 1 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - assert.Equal(t, 1, len(fragments)) - - fragmentKey, err := s3.GetFragmentKey(fileKey, prefixLength, 1, 0) - assert.NoError(t, err) - assert.Equal(t, fragmentKey, fragments[0].FragmentKey) - assert.Equal(t, data, fragments[0].Data) - assert.Equal(t, 0, fragments[0].Index) -} - -func TestRecombineFragments(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - data := tu.RandomBytes(1000) - prefixLength := rand.Intn(3) + 1 - fragmentSize := rand.Intn(100) + 100 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - recombinedData, err := s3.RecombineFragments(fragments) - assert.NoError(t, err) - assert.Equal(t, data, recombinedData) - - // Shuffle the fragments - for i := range fragments { - j := rand.Intn(i + 1) - fragments[i], fragments[j] = fragments[j], fragments[i] - } - - recombinedData, err = s3.RecombineFragments(fragments) - assert.NoError(t, err) - assert.Equal(t, data, recombinedData) -} - -func TestRecombineFragmentsSmallFile(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - data := tu.RandomBytes(10) - prefixLength := rand.Intn(3) + 1 - fragmentSize := rand.Intn(100) + 100 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - assert.Equal(t, 1, len(fragments)) - recombinedData, err := s3.RecombineFragments(fragments) - assert.NoError(t, err) - assert.Equal(t, data, recombinedData) -} - -func TestMissingFragment(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - data := tu.RandomBytes(1000) - prefixLength := rand.Intn(3) + 1 - fragmentSize := rand.Intn(100) + 100 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - - fragmentIndexToSkip := rand.Intn(len(fragments)) - fragments = append(fragments[:fragmentIndexToSkip], fragments[fragmentIndexToSkip+1:]...) - - _, err = s3.RecombineFragments(fragments[:len(fragments)-1]) - assert.Error(t, err) -} - -func TestMissingFinalFragment(t *testing.T) { - tu.InitializeRandom() - - fileKey := tu.RandomString(10) - data := tu.RandomBytes(1000) - prefixLength := rand.Intn(3) + 1 - fragmentSize := rand.Intn(100) + 100 - - fragments, err := s3.BreakIntoFragments(fileKey, data, prefixLength, fragmentSize) - assert.NoError(t, err) - fragments = fragments[:len(fragments)-1] - - _, err = s3.RecombineFragments(fragments) - assert.Error(t, err) -} diff --git a/common/mock/s3_client.go b/common/mock/s3_client.go index 7f505d56aa..d4e79645b0 100644 --- a/common/mock/s3_client.go +++ b/common/mock/s3_client.go @@ -17,6 +17,10 @@ func NewS3Client() *S3Client { return &S3Client{bucket: make(map[string][]byte)} } +func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error { + return nil +} + func (s *S3Client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) { data, ok := s.bucket[key] if !ok { @@ -44,30 +48,3 @@ func (s *S3Client) ListObjects(ctx context.Context, bucket string, prefix string } return objects, nil } - -func (s *S3Client) CreateBucket(ctx context.Context, bucket string) error { - return nil -} - -func (s *S3Client) FragmentedUploadObject( - ctx context.Context, - bucket string, - key string, - data []byte, - fragmentSize int) error { - s.bucket[key] = data - return nil -} - -func (s *S3Client) FragmentedDownloadObject( - ctx context.Context, - bucket string, - key string, - fileSize int, - fragmentSize int) ([]byte, error) { - data, ok := s.bucket[key] - if !ok { - return []byte{}, s3.ErrObjectNotFound - } - return data, nil -} diff --git a/inabox/deploy/localstack.go b/inabox/deploy/localstack.go index 6a89bbf6ed..020f807b65 100644 --- a/inabox/deploy/localstack.go +++ b/inabox/deploy/localstack.go @@ -8,7 +8,6 @@ import ( "net/http" "path/filepath" "runtime" - "runtime/debug" "time" "github.com/Layr-Labs/eigenda/common/aws" @@ -21,7 +20,6 @@ import ( ) func StartDockertestWithLocalstackContainer(localStackPort string) (*dockertest.Pool, *dockertest.Resource, error) { - debug.PrintStack() // TODO do not merge fmt.Println("Starting Localstack container") pool, err := dockertest.NewPool("") if err != nil {