Skip to content

Commit

Permalink
S3 relay interface (#833)
Browse files Browse the repository at this point in the history
Signed-off-by: Cody Littley <[email protected]>
  • Loading branch information
cody-littley authored Oct 30, 2024
1 parent 410419a commit aca3040
Show file tree
Hide file tree
Showing 8 changed files with 992 additions and 41 deletions.
102 changes: 91 additions & 11 deletions common/aws/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,48 @@ package aws
import (
"github.com/Layr-Labs/eigenda/common"
"github.com/urfave/cli"
"time"
)

var (
RegionFlagName = "aws.region"
AccessKeyIdFlagName = "aws.access-key-id"
SecretAccessKeyFlagName = "aws.secret-access-key"
EndpointURLFlagName = "aws.endpoint-url"
RegionFlagName = "aws.region"
AccessKeyIdFlagName = "aws.access-key-id"
SecretAccessKeyFlagName = "aws.secret-access-key"
EndpointURLFlagName = "aws.endpoint-url"
FragmentPrefixCharsFlagName = "aws.fragment-prefix-chars"
FragmentParallelismFactorFlagName = "aws.fragment-parallelism-factor"
FragmentParallelismConstantFlagName = "aws.fragment-parallelism-constant"
FragmentReadTimeoutFlagName = "aws.fragment-read-timeout"
FragmentWriteTimeoutFlagName = "aws.fragment-write-timeout"
)

type ClientConfig struct {
Region string
AccessKey string
// Region is the region to use when interacting with S3. Default is "us-east-2".
Region string
// AccessKey to use when interacting with S3.
AccessKey string
// SecretAccessKey to use when interacting with S3.
SecretAccessKey string
EndpointURL string
// EndpointURL of the S3 endpoint to use. If this is not set then the default AWS S3 endpoint will be used.
EndpointURL string

// FragmentPrefixChars is the number of characters of the key to use as the prefix for fragmented files.
// A value of "3" for the key "ABCDEFG" will result in the prefix "ABC". Default is 3.
FragmentPrefixChars int
// FragmentParallelismFactor helps determine the size of the pool of workers to help upload/download files.
// A non-zero value for this parameter adds a number of workers equal to the number of cores times this value.
// Default is 8. In general, the number of workers here can be a lot larger than the number of cores because the
// workers will be blocked on I/O most of the time.
FragmentParallelismFactor int
// FragmentParallelismConstant helps determine the size of the pool of workers to help upload/download files.
// A non-zero value for this parameter adds a constant number of workers. Default is 0.
FragmentParallelismConstant int
// FragmentReadTimeout is used to bound the maximum time to wait for a single fragmented read.
// Default is 30 seconds.
FragmentReadTimeout time.Duration
// FragmentWriteTimeout is used to bound the maximum time to wait for a single fragmented write.
// Default is 30 seconds.
FragmentWriteTimeout time.Duration
}

func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
Expand Down Expand Up @@ -48,14 +76,66 @@ func ClientFlags(envPrefix string, flagPrefix string) []cli.Flag {
Value: "",
EnvVar: common.PrefixEnvVar(envPrefix, "AWS_ENDPOINT_URL"),
},
cli.IntFlag{
Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName),
Usage: "The number of characters of the key to use as the prefix for fragmented files",
Required: false,
Value: 3,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PREFIX_CHARS"),
},
cli.IntFlag{
Name: common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName),
Usage: "Add this many threads times the number of cores to the worker pool",
Required: false,
Value: 8,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_FACTOR"),
},
cli.IntFlag{
Name: common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName),
Usage: "Add this many threads to the worker pool",
Required: false,
Value: 0,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_PARALLELISM_CONSTANT"),
},
cli.DurationFlag{
Name: common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName),
Usage: "The maximum time to wait for a single fragmented read",
Required: false,
Value: 30 * time.Second,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_READ_TIMEOUT"),
},
cli.DurationFlag{
Name: common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName),
Usage: "The maximum time to wait for a single fragmented write",
Required: false,
Value: 30 * time.Second,
EnvVar: common.PrefixEnvVar(envPrefix, "FRAGMENT_WRITE_TIMEOUT"),
},
}
}

func ReadClientConfig(ctx *cli.Context, flagPrefix string) ClientConfig {
return ClientConfig{
Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
Region: ctx.GlobalString(common.PrefixFlag(flagPrefix, RegionFlagName)),
AccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, AccessKeyIdFlagName)),
SecretAccessKey: ctx.GlobalString(common.PrefixFlag(flagPrefix, SecretAccessKeyFlagName)),
EndpointURL: ctx.GlobalString(common.PrefixFlag(flagPrefix, EndpointURLFlagName)),
FragmentPrefixChars: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentPrefixCharsFlagName)),
FragmentParallelismFactor: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismFactorFlagName)),
FragmentParallelismConstant: ctx.GlobalInt(common.PrefixFlag(flagPrefix, FragmentParallelismConstantFlagName)),
FragmentReadTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentReadTimeoutFlagName)),
FragmentWriteTimeout: ctx.GlobalDuration(common.PrefixFlag(flagPrefix, FragmentWriteTimeoutFlagName)),
}
}

// DefaultClientConfig returns a new ClientConfig with default values.
func DefaultClientConfig() *ClientConfig {
return &ClientConfig{
Region: "us-east-2",
FragmentPrefixChars: 3,
FragmentParallelismFactor: 8,
FragmentParallelismConstant: 0,
FragmentReadTimeout: 30 * time.Second,
FragmentWriteTimeout: 30 * time.Second,
}
}
225 changes: 200 additions & 25 deletions common/aws/s3/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"context"
"errors"
"github.com/gammazero/workerpool"
"runtime"
"sync"

commonaws "github.com/Layr-Labs/eigenda/common/aws"
Expand All @@ -27,7 +29,9 @@ type Object struct {
}

type client struct {
cfg *commonaws.ClientConfig
s3Client *s3.Client
pool *workerpool.WorkerPool
logger logging.Logger
}

Expand All @@ -36,18 +40,19 @@ var _ Client = (*client)(nil)
func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.Logger) (*client, error) {
var err error
once.Do(func() {
customResolver := aws.EndpointResolverWithOptionsFunc(func(service, region string, options ...interface{}) (aws.Endpoint, error) {
if cfg.EndpointURL != "" {
return aws.Endpoint{
PartitionID: "aws",
URL: cfg.EndpointURL,
SigningRegion: cfg.Region,
}, nil
}

// returning EndpointNotFoundError will allow the service to fallback to its default resolution
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})
customResolver := aws.EndpointResolverWithOptionsFunc(
func(service, region string, options ...interface{}) (aws.Endpoint, error) {
if cfg.EndpointURL != "" {
return aws.Endpoint{
PartitionID: "aws",
URL: cfg.EndpointURL,
SigningRegion: cfg.Region,
}, nil
}

// returning EndpointNotFoundError will allow the service to fallback to its default resolution
return aws.Endpoint{}, &aws.EndpointNotFoundError{}
})

options := [](func(*config.LoadOptions) error){
config.WithRegion(cfg.Region),
Expand All @@ -56,31 +61,42 @@ func NewClient(ctx context.Context, cfg commonaws.ClientConfig, logger logging.L
}
// If access key and secret access key are not provided, use the default credential provider
if len(cfg.AccessKey) > 0 && len(cfg.SecretAccessKey) > 0 {
options = append(options, config.WithCredentialsProvider(credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, "")))
options = append(options,
config.WithCredentialsProvider(
credentials.NewStaticCredentialsProvider(cfg.AccessKey, cfg.SecretAccessKey, "")))
}
awsConfig, errCfg := config.LoadDefaultConfig(context.Background(), options...)

if errCfg != nil {
err = errCfg
return
}

s3Client := s3.NewFromConfig(awsConfig, func(o *s3.Options) {
o.UsePathStyle = true
})
ref = &client{s3Client: s3Client, logger: logger.With("component", "S3Client")}
})
return ref, err
}

func (s *client) CreateBucket(ctx context.Context, bucket string) error {
_, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: aws.String(bucket),
})
if err != nil {
return err
}
workers := 0
if cfg.FragmentParallelismConstant > 0 {
workers = cfg.FragmentParallelismConstant
}
if cfg.FragmentParallelismFactor > 0 {
workers = cfg.FragmentParallelismFactor * runtime.NumCPU()
}

return nil
if workers == 0 {
workers = 1
}
pool := workerpool.New(workers)

ref = &client{
cfg: &cfg,
s3Client: s3Client,
pool: pool,
logger: logger.With("component", "S3Client"),
}
})
return ref, err
}

func (s *client) DownloadObject(ctx context.Context, bucket string, key string) ([]byte, error) {
Expand Down Expand Up @@ -159,3 +175,162 @@ func (s *client) ListObjects(ctx context.Context, bucket string, prefix string)
}
return objects, nil
}

func (s *client) CreateBucket(ctx context.Context, bucket string) error {
_, err := s.s3Client.CreateBucket(ctx, &s3.CreateBucketInput{
Bucket: aws.String(bucket),
})
if err != nil {
return err
}

return nil
}

func (s *client) FragmentedUploadObject(
ctx context.Context,
bucket string,
key string,
data []byte,
fragmentSize int) error {

fragments, err := BreakIntoFragments(key, data, s.cfg.FragmentPrefixChars, fragmentSize)
if err != nil {
return err
}
resultChannel := make(chan error, len(fragments))

ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
defer cancel()

for _, fragment := range fragments {
fragmentCapture := fragment
s.pool.Submit(func() {
s.fragmentedWriteTask(ctx, resultChannel, fragmentCapture, bucket)
})
}

for range fragments {
err := <-resultChannel
if err != nil {
return err
}
}
return ctx.Err()

}

// fragmentedWriteTask writes a single file to S3.
func (s *client) fragmentedWriteTask(
ctx context.Context,
resultChannel chan error,
fragment *Fragment,
bucket string) {

_, err := s.s3Client.PutObject(ctx,
&s3.PutObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(fragment.FragmentKey),
Body: bytes.NewReader(fragment.Data),
})

resultChannel <- err
}

func (s *client) FragmentedDownloadObject(
ctx context.Context,
bucket string,
key string,
fileSize int,
fragmentSize int) ([]byte, error) {

if fragmentSize <= 0 {
return nil, errors.New("fragmentSize must be greater than 0")
}

fragmentKeys, err := GetFragmentKeys(key, s.cfg.FragmentPrefixChars, GetFragmentCount(fileSize, fragmentSize))
if err != nil {
return nil, err
}
resultChannel := make(chan *readResult, len(fragmentKeys))

ctx, cancel := context.WithTimeout(ctx, s.cfg.FragmentWriteTimeout)
defer cancel()

for i, fragmentKey := range fragmentKeys {
boundFragmentKey := fragmentKey
boundI := i
s.pool.Submit(func() {
s.readTask(ctx, resultChannel, bucket, boundFragmentKey, boundI)
})
}

fragments := make([]*Fragment, len(fragmentKeys))
for i := 0; i < len(fragmentKeys); i++ {
result := <-resultChannel
if result.err != nil {
return nil, result.err
}
fragments[result.fragment.Index] = result.fragment
}

if ctx.Err() != nil {
return nil, ctx.Err()
}

return RecombineFragments(fragments)

}

// readResult is the result of a read task.
type readResult struct {
fragment *Fragment
err error
}

// readTask reads a single file from S3.
func (s *client) readTask(
ctx context.Context,
resultChannel chan *readResult,
bucket string,
key string,
index int) {

result := &readResult{}
defer func() {
resultChannel <- result
}()

ret, err := s.s3Client.GetObject(ctx, &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
})

if err != nil {
result.err = err
return
}

data := make([]byte, *ret.ContentLength)
bytesRead := 0

for bytesRead < len(data) && ctx.Err() == nil {
count, err := ret.Body.Read(data[bytesRead:])
if err != nil && err.Error() != "EOF" {
result.err = err
return
}
bytesRead += count
}

result.fragment = &Fragment{
FragmentKey: key,
Data: data,
Index: index,
}

err = ret.Body.Close()
if err != nil {
result.err = err
}
}
Loading

0 comments on commit aca3040

Please sign in to comment.