Skip to content
This repository has been archived by the owner on Jul 21, 2023. It is now read-only.

Commit

Permalink
Use env credentials if no identity is provided. (#178)
Browse files Browse the repository at this point in the history
  • Loading branch information
jsha authored Nov 11, 2020
1 parent e728734 commit 1467d44
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions workflow-manager/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
aws_session "github.com/aws/aws-sdk-go/aws/session"
Expand Down Expand Up @@ -317,6 +318,21 @@ func (b *bucket) listFiles(ctx context.Context) ([]string, error) {
}
}

func webIDP(sess *aws_session.Session, identity string) (*credentials.Credentials, error) {
parsed, err := arn.Parse(identity)
if err != nil {
return nil, err
}
audience := fmt.Sprintf("sts.amazonaws.com/%s", parsed.AccountID)

stsSTS := sts.New(sess)
roleSessionName := ""
roleProvider := stscreds.NewWebIdentityRoleProviderWithToken(
stsSTS, identity, roleSessionName, tokenFetcher{audience})

return credentials.NewCredentials(roleProvider), nil
}

func (b *bucket) listFilesS3(ctx context.Context) ([]string, error) {
parts := strings.SplitN(b.bucketName, "/", 2)
if len(parts) != 2 {
Expand All @@ -329,26 +345,19 @@ func (b *bucket) listFilesS3(ctx context.Context) ([]string, error) {
return nil, fmt.Errorf("making AWS session: %w", err)
}

arnComponents := strings.Split(b.identity, ":")
if arnComponents[0] != "arn" {
return nil, fmt.Errorf("invalid AWS identity %q. Must start with \"arn:\"", b.identity)
}
if len(arnComponents) != 6 {
return nil, fmt.Errorf("invalid ARN: %q", b.identity)
var creds *credentials.Credentials
if b.identity != "" {
creds, err = webIDP(sess, b.identity)
if err != nil {
return nil, err
}
} else {
creds = credentials.NewEnvCredentials()
}
audience := fmt.Sprintf("sts.amazonaws.com/%s", arnComponents[4])

stsSTS := sts.New(sess)
roleSessionName := ""
roleProvider := stscreds.NewWebIdentityRoleProviderWithToken(
stsSTS, b.identity, roleSessionName, tokenFetcher{audience})

credentials := credentials.NewCredentials(roleProvider)
log.Printf("listing files in s3://%s as %s", bucket, b.identity)

log.Printf("listing files in s3://%s as %q", bucket, b.identity)
config := aws.NewConfig().
WithRegion(region).
WithCredentials(credentials)
WithCredentials(creds)
svc := s3.New(sess, config)
var output []string
var nextContinuationToken string = ""
Expand Down

0 comments on commit 1467d44

Please sign in to comment.