Skip to content

Commit

Permalink
refactor(cloudtrail): Get S3 keys concurrently
Browse files Browse the repository at this point in the history
To speed up the process of getting all the keys, divide the inputParams
array into chunks and get the keys for each item in the chunk concurrently.

Signed-off-by: Uli Heilmeier <[email protected]>
  • Loading branch information
uhei committed Mar 12, 2024
1 parent 9920d35 commit 6d09ac5
Showing 1 changed file with 84 additions and 42 deletions.
126 changes: 84 additions & 42 deletions plugins/cloudtrail/pkg/cloudtrail/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ const (
sqsMode
)

// listOrigin describes one starting point for an S3 bucket listing:
// an optional key prefix and an optional key to start listing after,
// passed through to s3.ListObjectsV2Input (nil means "not set").
type listOrigin struct {
	prefix     *string
	startAfter *string
}


type fileInfo struct {
name string
isCompressed bool
Expand Down Expand Up @@ -93,6 +99,8 @@ type PluginInstance struct {
nextJParser fastjson.Parser
}

// dlErrChan carries errors out of concurrent S3 worker goroutines
// (e.g. listKeys) back to the coordinating goroutine, which drains it
// after DownloadWg.Wait().
// NOTE(review): package-level mutable state that is re-made per batch of
// workers — presumably safe because it is only reassigned between
// WaitGroup barriers, but verify no worker outlives its batch.
var dlErrChan chan error

func min(a, b int) int {
if a < b {
return a
Expand Down Expand Up @@ -153,6 +161,70 @@ func (p *PluginInstance) initS3() error {
return nil
}

// chunkListOrigin splits orgList into consecutive sub-slices of at most
// chunkSize elements each. The chunks share orgList's backing array (no
// copying). An empty input yields nil. A chunkSize below 1 is treated as
// 1; the original code would panic with a division by zero if the
// configured concurrency were ever 0.
func chunkListOrigin(orgList []listOrigin, chunkSize int) [][]listOrigin {
	if len(orgList) == 0 {
		return nil
	}
	if chunkSize < 1 {
		// Defensive: avoid a divide-by-zero on the capacity computation
		// below when the concurrency setting is misconfigured.
		chunkSize = 1
	}
	divided := make([][]listOrigin, 0, (len(orgList)+chunkSize-1)/chunkSize)
	for start := 0; start < len(orgList); start += chunkSize {
		end := min(start+chunkSize, len(orgList))
		divided = append(divided, orgList[start:end])
	}
	return divided
}

// listKeys pages through the S3 bucket listing described by params and
// appends every matching CloudTrail log key (plain or gzipped JSON,
// optionally restricted to the [startTS, endTS] timestamp window) to
// oCtx.files. It is designed to run as one of several concurrent
// goroutines: the caller must call oCtx.s3.DownloadWg.Add(1) beforehand,
// and listing errors are reported on the package-level dlErrChan rather
// than via the return value (which is always nil).
func (oCtx *PluginInstance) listKeys(params listOrigin, startTS string, endTS string) error {
	defer oCtx.s3.DownloadWg.Done()

	// Compile the timestamp-extraction pattern once per call. The previous
	// code compiled it inside the per-object loop, i.e. once per S3 key,
	// which is pure wasted work on large buckets.
	filepathRE := regexp.MustCompile(`.*_CloudTrail_[^_]+_([^_]+)Z_`)

	ctx := context.Background()
	// Fetch the list of keys under the configured prefix, optionally
	// resuming after a given key.
	paginator := s3.NewListObjectsV2Paginator(oCtx.s3.client, &s3.ListObjectsV2Input{
		Bucket:     &oCtx.s3.bucket,
		Prefix:     params.prefix,
		StartAfter: params.startAfter,
	})

	for paginator.HasMorePages() {
		page, err := paginator.NextPage(ctx)
		if err != nil {
			// Surface the error to the coordinator; dlErrChan is buffered
			// with one slot per worker, so this send cannot block.
			dlErrChan <- err
			return nil
		}
		for _, obj := range page.Contents {
			path := obj.Key

			// Skip keys whose embedded timestamp falls outside the
			// requested window (keys without a parseable timestamp are
			// kept, matching the original behavior).
			if startTS != "" {
				if matches := filepathRE.FindStringSubmatch(*path); matches != nil {
					pathTS := matches[1]
					if pathTS < startTS {
						continue
					}
					if endTS != "" && pathTS > endTS {
						continue
					}
				}
			}

			// Only plain or gzip-compressed JSON logs are of interest.
			isCompressed := strings.HasSuffix(*path, ".json.gz")
			if filepath.Ext(*path) != ".json" && !isCompressed {
				continue
			}

			// NOTE(review): several listKeys goroutines run concurrently
			// and append to oCtx.files without synchronization — this is a
			// data race. Guard oCtx.files with a mutex, or collect
			// per-goroutine slices and merge them after DownloadWg.Wait().
			oCtx.files = append(oCtx.files, fileInfo{name: *path, isCompressed: isCompressed})
		}
	}
	return nil
}

func (oCtx *PluginInstance) openS3(input string) error {
oCtx.openMode = s3Mode

Expand All @@ -175,11 +247,6 @@ func (oCtx *PluginInstance) openS3(input string) error {
}


type listOrigin struct {
prefix *string
startAfter *string
}

var inputParams []listOrigin
ctx := context.Background()
var intervalPrefixList []string
Expand Down Expand Up @@ -291,7 +358,6 @@ func (oCtx *PluginInstance) openS3(input string) error {
}
}

filepathRE := regexp.MustCompile(`.*_CloudTrail_[^_]+_([^_]+)Z_`)
var startTS string
var endTS string

Expand All @@ -312,17 +378,18 @@ func (oCtx *PluginInstance) openS3(input string) error {
inputParams = append(inputParams, params)
}

// Would it make sense to do this concurrently?
for _, params := range inputParams {
// Fetch the list of keys
paginator := s3.NewListObjectsV2Paginator(oCtx.s3.client, &s3.ListObjectsV2Input{
Bucket: &oCtx.s3.bucket,
Prefix: params.prefix,
StartAfter: params.startAfter,
})
// Divide the inputParams array into chunks and get the keys concurrently for all items in a chunk
for _, chunk := range chunkListOrigin(inputParams, oCtx.config.S3DownloadConcurrency) {
dlErrChan = make(chan error, oCtx.config.S3DownloadConcurrency)
for _, params := range chunk {
oCtx.s3.DownloadWg.Add(1)
go oCtx.listKeys(params, startTS, endTS)
}

oCtx.s3.DownloadWg.Wait()

for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
select {
case err := <-dlErrChan:
if err != nil {
// Try friendlier error sources first.
var aErr smithy.APIError
Expand All @@ -337,30 +404,7 @@ func (oCtx *PluginInstance) openS3(input string) error {

return fmt.Errorf(PluginName + " plugin error: failed to list objects: " + err.Error())
}
for _, obj := range page.Contents {
path := obj.Key

if startTS != "" {
matches := filepathRE.FindStringSubmatch(*path)
if matches != nil {
pathTS := matches[1]
if pathTS < startTS {
continue
}
if endTS != "" && pathTS > endTS {
continue
}
}
}

isCompressed := strings.HasSuffix(*path, ".json.gz")
if filepath.Ext(*path) != ".json" && !isCompressed {
continue
}

var fi fileInfo = fileInfo{name: *path, isCompressed: isCompressed}
oCtx.files = append(oCtx.files, fi)
}
default:
}
}

Expand Down Expand Up @@ -508,8 +552,6 @@ func (oCtx *PluginInstance) openSQS(input string) error {
return oCtx.getMoreSQSFiles()
}

var dlErrChan chan error

func (oCtx *PluginInstance) s3Download(downloader *manager.Downloader, name string, dloadSlotNum int) {
defer oCtx.s3.DownloadWg.Done()

Expand Down

0 comments on commit 6d09ac5

Please sign in to comment.