diff --git a/store/dynamodb/db.go b/store/dynamodb/db.go index 5d26a892..7c93003a 100644 --- a/store/dynamodb/db.go +++ b/store/dynamodb/db.go @@ -19,12 +19,15 @@ package dynamodb import ( "errors" + "fmt" "net/http" + "os" "time" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" "github.com/go-playground/validator/v10" "github.com/xmidt-org/argus/model" "github.com/xmidt-org/argus/store" @@ -75,10 +78,12 @@ type Config struct { GetAllLimit int // AccessKey is the AWS AccessKey credential. - AccessKey string `validate:"required"` + // AccessKey string `validate:"required"` + AccessKey string // SecretKey is the AWS SecretKey credential. - SecretKey string `validate:"required"` + // SecretKey string `validate:"required"` + SecretKey string // DisableDualStack indicates whether the connection to the DB should be // dual stack (IPv4 and IPv6). @@ -104,16 +109,41 @@ func NewDynamoDB(config Config, measures metric.Measures) (store.S, error) { return nil, err } + var creds credentials.Value + awsRegion, err := getAwsRegionForRoleBasedAccess(config) + if err != nil { + return nil, err + } + + sess, err := session.NewSession(&aws.Config{ + Region: aws.String(awsRegion)}, + ) + if err != nil { + return nil, err + } + + value, err := sess.Config.Credentials.Get() + if err != nil { + return nil, err + } + + creds = credentials.Value{ + AccessKeyID: value.AccessKeyID, + SecretAccessKey: value.SecretAccessKey, + SessionToken: value.SessionToken, + } + + fmt.Println("This is the access key: ", value.AccessKeyID) + fmt.Println("This is the secret access key: ", value.SecretAccessKey) + fmt.Println("This is the session token: ", value.SessionToken) + awsConfig := *aws.NewConfig(). WithEndpoint(config.Endpoint). WithUseDualStack(!config.DisableDualStack). WithMaxRetries(config.MaxRetries). WithCredentialsChainVerboseErrors(true). WithRegion(config.Region). - WithCredentials(credentials.NewStaticCredentialsFromCreds(credentials.Value{ - AccessKeyID: config.AccessKey, - SecretAccessKey: config.SecretKey, - })) + WithCredentials(credentials.NewStaticCredentialsFromCreds(creds)) svc, err := newService(awsConfig, "", config.Table, int64(config.GetAllLimit), &measures) if err != nil { @@ -158,3 +188,17 @@ func sanitizeError(err error) error { } return store.SanitizeError(err) } + +func getAwsRegionForRoleBasedAccess(config Config) (string, error) { + awsRegion := config.Region + + if len(awsRegion) == 0 { + awsRegion = os.Getenv("AWS_REGION") + } + + if len(awsRegion) == 0 { + return "", fmt.Errorf("%s", "Aws region is not provided") + } + + return awsRegion, nil +} diff --git a/store/dynamodb/service.go b/store/dynamodb/service.go index 816aaf68..62fec582 100644 --- a/store/dynamodb/service.go +++ b/store/dynamodb/service.go @@ -193,13 +193,13 @@ func (d *executor) Delete(key model.Key) (store.OwnableItem, *dynamodb.ConsumedC return d.getOrDelete(key, true) } -//TODO: For data >= 1MB, we'll need to handle pagination +// TODO: For data >= 1MB, we'll need to handle pagination func (d *executor) GetAll(bucket string) (map[string]store.OwnableItem, *dynamodb.ConsumedCapacity, error) { result := map[string]store.OwnableItem{} now := strconv.Itoa(int(d.now().Unix())) input := &dynamodb.QueryInput{ TableName: aws.String(d.tableName), - IndexName: aws.String("Expires-index"), + IndexName: aws.String("expires-index"), KeyConditions: map[string]*dynamodb.Condition{ "bucket": { ComparisonOperator: aws.String("EQ"),