diff --git a/agent/managedInstances/rolecreds/managedInstances_role_provider_test.go b/agent/managedInstances/rolecreds/managedInstances_role_provider_test.go index 83e2c0ba1..576547107 100644 --- a/agent/managedInstances/rolecreds/managedInstances_role_provider_test.go +++ b/agent/managedInstances/rolecreds/managedInstances_role_provider_test.go @@ -21,6 +21,8 @@ import ( "testing" "time" + "github.com/aws/amazon-ssm-agent/agent/fileutil" + "github.com/aws/amazon-ssm-agent/agent/managedInstances/sharedCredentials" "github.com/aws/aws-sdk-go/service/ssm" "github.com/stretchr/testify/assert" ) @@ -29,8 +31,17 @@ var ( accessKeyID = "accessKeyID" secretAccessKey = "secretAccessKey" sessionToken = "sessionToken" + region = "us-east-1" ) +func cleanupCredFile() { + if credPath, err := sharedCredentials.Filename(); err == nil { + if credPath != "" && fileutil.Exists(credPath) { + fileutil.DeleteFile(credPath) + } + } +} + func TestRetrieve_ShouldReturnValidToken(t *testing.T) { updateKeyPair := false tokenExpirationDate := time.Now().Add(1 * time.Hour) @@ -51,6 +62,8 @@ func TestRetrieve_ShouldReturnValidToken(t *testing.T) { assert.Equal(t, accessKeyID, cred.AccessKeyID) assert.Equal(t, secretAccessKey, cred.SecretAccessKey) assert.Equal(t, sessionToken, cred.SessionToken) + + cleanupCredFile() } func TestRetrieve_ShouldUpdateKeyPair(t *testing.T) { @@ -60,6 +73,7 @@ func TestRetrieve_ShouldUpdateKeyPair(t *testing.T) { publicKey: "publicKey", privateKey: "privateKey", keyType: "Rsa", + region: "us-east-1", } client := &RsaSignedServiceStub{ roleResponse: ssm.RequestManagedInstanceRoleTokenOutput{ @@ -76,6 +90,7 @@ func TestRetrieve_ShouldUpdateKeyPair(t *testing.T) { _, err := testProvider.Retrieve() assert.NoError(t, err) assert.True(t, client.updateCalled) + cleanupCredFile() } func TestRetrieve_ShouldFailOnError(t *testing.T) { diff --git a/agent/managedInstances/sharedCredentials/shared_Credentials.go b/agent/managedInstances/sharedCredentials/shared_Credentials.go index 77b2ed7f6..bb9e6050d 100644 --- a/agent/managedInstances/sharedCredentials/shared_Credentials.go +++ b/agent/managedInstances/sharedCredentials/shared_Credentials.go @@ -20,6 +20,7 @@ import ( "path/filepath" "github.com/aws/amazon-ssm-agent/agent/fileutil" + "github.com/aws/amazon-ssm-agent/agent/managedInstances/registration" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/go-ini/ini" ) @@ -29,12 +30,12 @@ const ( awsAccessKeyID = "aws_access_key_id" awsSecretAccessKey = "aws_secret_access_key" awsSessionToken = "aws_session_token" + awsRegion = "region" ) -// filename returns the filename to use to read AWS shared credentials. -// +// Filename returns the filename to use to read AWS shared credentials. // Will return an error if the user's home directory path cannot be found. -func filename() (string, error) { +func Filename() (string, error) { if credPath := os.Getenv("AWS_SHARED_CREDENTIALS_FILE"); credPath != "" { return credPath, nil } @@ -68,7 +69,7 @@ func Store(accessKeyID, secretAccessKey, sessionToken, profile string) error { profile = defaultProfile } - credPath, err := filename() + credPath, err := Filename() if err != nil { return err } @@ -98,6 +99,18 @@ func Store(accessKeyID, secretAccessKey, sessionToken, profile string) error { iniProfile.Key(awsSessionToken).SetValue(sessionToken) + // Save the instance's region to the profile so that the FallbackRegionFactory can find it. + // Scripts that use the .NET Cmdlets and aws command line tools will automatically detect + // the AWS Region from the EC2 instance profile, however, this is not the case for on-prem + // servers, since they don't have the EC2 Metadata service. By adding the Region to the + // shared credentials file, the SDK will be able to discover the region automatically. + // This will ensure that scripts that run on on-prem servers will run the same way as + // they would on EC2 instances, without any modification. + region := registration.Region() + if region != "" { + iniProfile.Key(awsRegion).SetValue(region) + } + err = config.SaveTo(credPath) if err != nil { return awserr.New("SharedCredentialsStore", "failed to save profile", err) diff --git a/agent/managedInstances/sharedCredentials/shared_credentials_integ_test.go b/agent/managedInstances/sharedCredentials/shared_credentials_integ_test.go index 90df97880..bd5942d9f 100644 --- a/agent/managedInstances/sharedCredentials/shared_credentials_integ_test.go +++ b/agent/managedInstances/sharedCredentials/shared_credentials_integ_test.go @@ -30,6 +30,7 @@ const ( accessKey = "DummyAccessKey" accessSecretKey = "DummyAccessSecretKey" token = "DummyToken" + region = "us-east-1" profile = "DummyProfile" testFilePath = "example.ini" ) diff --git a/agent/s3util/s3util.go b/agent/s3util/s3util.go index bd1b4fce2..c50b68a9b 100644 --- a/agent/s3util/s3util.go +++ b/agent/s3util/s3util.go @@ -116,8 +116,8 @@ func (u *AmazonS3Util) S3Upload(log log.T, bucketName string, objectKey string, func GetBucketRegion(log log.T, bucketName string, httpProvider HttpProvider) (region string) { instanceRegion, err := getRegion() if err != nil { - log.Error("Cannot get the current instance region information") - return instanceRegion // Default + log.Error(fmt.Errorf("Cannot get the current instance region information: %v", err)) + return "us-east-1" // Default } log.Infof("Instance region is %v", instanceRegion)