From 7ae258d64f9edb8fabdcd66743085586d0197d67 Mon Sep 17 00:00:00 2001 From: Eugene Dementyev Date: Thu, 7 Oct 2021 14:21:20 +1300 Subject: [PATCH] Update to AWS SDK v2 Brings support for AWS SSO. --- go.mod | 5 ++- go.sum | 24 ++++++++++- lib/aws.go | 88 +++++++++++++++++++-------------------- lib/ec2connect/connect.go | 32 +++++++------- lib/util.go | 30 ++++++------- 5 files changed, 99 insertions(+), 80 deletions(-) diff --git a/go.mod b/go.mod index 1a564fe..21ed101 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,10 @@ module aws-ssh require ( github.com/apex/log v1.9.0 - github.com/aws/aws-sdk-go v1.38.35 + github.com/aws/aws-sdk-go-v2 v1.9.1 + github.com/aws/aws-sdk-go-v2/config v1.8.2 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.18.0 + github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect v1.5.1 github.com/go-ini/ini v1.48.0 github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e // indirect github.com/hashicorp/go-multierror v1.1.1 diff --git a/go.sum b/go.sum index f1fee64..72be8a9 100644 --- a/go.sum +++ b/go.sum @@ -133,8 +133,29 @@ github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.25.37/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.36.30/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= -github.com/aws/aws-sdk-go v1.38.35 h1:7AlAO0FC+8nFjxiGKEmq0QLpiA8/XFr6eIxgRTwkdTg= github.com/aws/aws-sdk-go v1.38.35/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= +github.com/aws/aws-sdk-go-v2 v1.9.1 h1:ZbovGV/qo40nrOJ4q8G33AGICzaPI45FHQWJ9650pF4= +github.com/aws/aws-sdk-go-v2 v1.9.1/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4= +github.com/aws/aws-sdk-go-v2/config v1.8.2 h1:Dqy4ySXFmulRmZhfynm/5CD4Y6aXiTVhDtXLIuUe/r0= +github.com/aws/aws-sdk-go-v2/config v1.8.2/go.mod h1:r0bkX9NyuCuf28qVcsEMtpAQibT7gA1Q0gzkjvgJdLU= +github.com/aws/aws-sdk-go-v2/credentials v1.4.2 h1:8kVE4Og6wlhVrMGiORQ3p9gRj2exjzhFRB+QzWBUa5Q= +github.com/aws/aws-sdk-go-v2/credentials v1.4.2/go.mod h1:9Sp6u121/f0NnvHyhG7dgoYeUTEFC2vsvJqJ6wXpkaI= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.5.1 h1:Nm+BxqBtT0r+AnD6byGMCGT4Km0QwHBy8mAYptNPXY4= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.5.1/go.mod h1:W1ldHfsgeGlKpJ4xZMKZUI6Wmp6EAstU7PxnhbXWWrI= +github.com/aws/aws-sdk-go-v2/internal/ini v1.2.3 h1:NnXJXUz7oihrSlPKEM0yZ19b+7GQ47MX/LluLlEyE/Y= +github.com/aws/aws-sdk-go-v2/internal/ini v1.2.3/go.mod h1:EES9ToeC3h063zCFDdqWGnARExNdULPaBvARm1FLwxA= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.18.0 h1:5wWtSfYRWgkpKKMW4yJ5llzI9s24Fls7Pv7uw2BiYbk= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.18.0/go.mod h1:d8R2f1hFcknkA3MW4SeExwEua2KpR+dhSrwWlnlwe5Q= +github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect v1.5.1 h1:Nr9llH7oJN3drO0lQgCganTN+3I+AzMTGRPzKo30X3U= +github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect v1.5.1/go.mod h1:iHBeiwp3Xfp7NO//QLJIlk4j5zfH0APBzqpQMSGnCAA= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.3.1 h1:APEjhKZLFlNVLATnA/TJyA+w1r/xd5r5ACWBDZ9aIvc= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.3.1/go.mod h1:Ve+eJOx9UWaT/lMVebnFhDhO49fSLVedHoA82+Rqme0= +github.com/aws/aws-sdk-go-v2/service/sso v1.4.1 h1:RfgQyv3bFT2Js6XokcrNtTjQ6wAVBRpoCgTFsypihHA= +github.com/aws/aws-sdk-go-v2/service/sso v1.4.1/go.mod h1:ycPdbJZlM0BLhuBnd80WX9PucWPG88qps/2jl9HugXs= +github.com/aws/aws-sdk-go-v2/service/sts v1.7.1 h1:7ce9ugapSgBapwLhg7AJTqKW5U92VRX3vX65k2tsB+g= +github.com/aws/aws-sdk-go-v2/service/sts v1.7.1/go.mod h1:r1i8QwKPzwByXqZb3POQfBs7jozrdnHz8PVbsvyx73w= +github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc= +github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E= github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= @@ -1170,6 +1191,7 @@ golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= google.golang.org/api v0.5.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE= diff --git a/lib/aws.go b/lib/aws.go index 79f79e8..a690fc2 100644 --- a/lib/aws.go +++ b/lib/aws.go @@ -1,13 +1,15 @@ package lib import ( + "context" "fmt" "sort" "github.com/apex/log" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" multierror "github.com/hashicorp/go-multierror" linq "gopkg.in/ahmetb/go-linq.v3" ) @@ -17,7 +19,7 @@ import ( type profileSummary struct { ProfileConfig - Instances []*ec2.Instance + Instances []types.Instance } // ProcessedProfileSummary represents profile summary @@ -28,21 +30,6 @@ type ProcessedProfileSummary struct { SSHEntries []SSHEntry } -func makeSession(profile string) (*session.Session, error) { - log.Debugf("Creating session for %s", profile) - // create AWS session - localSession, err := session.NewSessionWithOptions(session.Options{ - Config: aws.Config{}, - - SharedConfigState: session.SharedConfigEnable, - Profile: profile, - }) - if err != nil { - return nil, fmt.Errorf("can't get aws session") - } - return localSession, nil -} - // TraverseProfiles goes through all profiles and returns a list of ProcessedProfileSummary func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]ProcessedProfileSummary, error) { log.Debugf("Traversing through %d profiles", len(profiles)) @@ -86,28 +73,28 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process linq.From(summary.Instances).OrderBy(instanceNameSorter). // sort by name first ThenBy(instanceLaunchTimeSorter). // then by launch time GroupBy(func(i interface{}) interface{} { // and then group by vpc - vpcID := i.(*ec2.Instance).VpcId - return aws.StringValue(vpcID) + vpcID := i.(*types.Instance).VpcId + return aws.ToString(vpcID) }, func(i interface{}) interface{} { - return i.(*ec2.Instance) + return i.(*types.Instance) }).ToSlice(&vpcInstances) - var commonBastions []*ec2.Instance + var commonBastions []*types.Instance linq.From(summary.Instances).OrderBy(instanceNameSorter). // sort by name first ThenBy(instanceLaunchTimeSorter). // then by launch time Where( func(f interface{}) bool { - return isBastionFromTags(f.(*ec2.Instance).Tags, true) // check for global tag as well + return isBastionFromTags(f.(*types.Instance).Tags, true) // check for global tag as well }, ).ToSlice(&commonBastions) ctx.Debugf("Found %d common (global) bastions", len(commonBastions)) for _, vpcGroup := range vpcInstances { // take the instances grouped by vpc and iterate - var vpcBastions []*ec2.Instance + var vpcBastions []*types.Instance linq.From(vpcGroup.Group).Where( func(f interface{}) bool { - return isBastionFromTags(f.(*ec2.Instance).Tags, false) // "false" means don't check for global tag + return isBastionFromTags(f.(*types.Instance).Tags, false) // "false" means don't check for global tag }, ).ToSlice(&vpcBastions) @@ -115,10 +102,10 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process var nameInstances []linq.Group linq.From(vpcGroup.Group).GroupBy(func(i interface{}) interface{} { // now group them by name - instanceName := getNameFromTags(i.(*ec2.Instance).Tags) + instanceName := getNameFromTags(i.(*types.Instance).Tags) return instanceName }, func(i interface{}) interface{} { - return i.(*ec2.Instance) + return i.(*types.Instance) }).ToSlice(&nameInstances) // now we have instances, grouped by vpc and name @@ -126,9 +113,9 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process instanceName := nameGroup.Key.(string) for n, instance := range nameGroup.Group { - instance := instance.(*ec2.Instance) + instance := instance.(*types.Instance) var entry = SSHEntry{ - InstanceID: aws.StringValue(instance.InstanceId), + InstanceID: aws.ToString(instance.InstanceId), ProfileConfig: ProfileConfig{ Name: summary.Name, Region: summary.Region, @@ -143,14 +130,14 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process if bastion == nil { // then try common ones bastion = findBestBastion(instanceName, commonBastions) } - entry.Address = aws.StringValue(instance.PrivateIpAddress) // get the private address first as we always have one - if bastion != nil { // get private address and add proxyhost, which is the bastion ip + entry.Address = aws.ToString(instance.PrivateIpAddress) // get the private address first as we always have one + if bastion != nil { // get private address and add proxyhost, which is the bastion ip // refer to the bastion by its instance ID // which we should have a record for - entry.ProxyJump = aws.StringValue(bastion.InstanceId) + entry.ProxyJump = aws.ToString(bastion.InstanceId) } else { // get public IP if we have one - if publicIP := aws.StringValue(instance.PublicIpAddress); publicIP != "" { - entry.Address = aws.StringValue(instance.PublicIpAddress) + if publicIP := aws.ToString(instance.PublicIpAddress); publicIP != "" { + entry.Address = aws.ToString(instance.PublicIpAddress) } } var instanceIndex string @@ -187,7 +174,9 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process // DescribeProfile describes the specified profile func DescribeProfile(profile ProfileConfig, sum chan profileSummary, errChan chan error) { - awsSession, err := makeSession(profile.Name) + cfg, err := config.LoadDefaultConfig(context.TODO(), + config.WithSharedConfigProfile(profile.Name)) + if err != nil { errChan <- fmt.Errorf("Couldn't create session for '%s': %s", profile.Name, err) return @@ -196,29 +185,36 @@ func DescribeProfile(profile ProfileConfig, sum chan profileSummary, errChan cha profileSummary := profileSummary{ ProfileConfig: ProfileConfig{ Name: profile.Name, - Region: aws.StringValue(awsSession.Config.Region), + Region: cfg.Region, Domain: profile.Domain, }, } - svc := ec2.New(awsSession) + svc := ec2.NewFromConfig(cfg) input := &ec2.DescribeInstancesInput{ - Filters: []*ec2.Filter{ + Filters: []types.Filter{ { Name: aws.String("instance-state-name"), - Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}), + Values: []string{string(types.InstanceStateNameRunning)}, }, }, } - err = svc.DescribeInstancesPages(input, func(result *ec2.DescribeInstancesOutput, lastPage bool) bool { - for _, reservation := range result.Reservations { - for _, instance := range reservation.Instances { - profileSummary.Instances = append(profileSummary.Instances, instance) + paginator := ec2.NewDescribeInstancesPaginator(svc, input) + for paginator.HasMorePages() { + result, subErr := paginator.NextPage(context.TODO()) + if subErr != nil { + err = subErr + break + } else { + for _, reservation := range result.Reservations { + for _, instance := range reservation.Instances { + profileSummary.Instances = append(profileSummary.Instances, instance) + } } } - return false - }) + } + if err != nil { errChan <- fmt.Errorf("Can't get full information for '%s': %s", profile, err) } else { diff --git a/lib/ec2connect/connect.go b/lib/ec2connect/connect.go index 643d881..c37c183 100644 --- a/lib/ec2connect/connect.go +++ b/lib/ec2connect/connect.go @@ -2,6 +2,7 @@ package ec2connect import ( "aws-ssh/lib" + "context" "fmt" "net" @@ -11,10 +12,10 @@ import ( "syscall" "github.com/apex/log" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2instanceconnect" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect" "golang.org/x/crypto/ssh/agent" ) @@ -97,18 +98,15 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string) // and returns the public (or private if public doesn't exist) address of the EC2 instance func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, string, error) { ctx := log.WithField("instance_id", instanceID) - localSession, err := session.NewSessionWithOptions(session.Options{ - Config: aws.Config{}, + cfg, err := config.LoadDefaultConfig(context.TODO(), + config.WithSharedConfigProfile(profile)) - SharedConfigState: session.SharedConfigEnable, - Profile: profile, - }) if err != nil { return "", "", fmt.Errorf("can't get aws session: %s", err) } - ec2Svc := ec2.New(localSession) - ec2Result, err := ec2Svc.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice([]string{instanceID}), + ec2Svc := ec2.NewFromConfig(cfg) + ec2Result, err := ec2Svc.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{ + InstanceIds: []string{instanceID}, }) if err != nil { return "", "", fmt.Errorf("can't get ec2 instance: %s", err) @@ -119,7 +117,7 @@ func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, s } ec2Instance := ec2Result.Reservations[0].Instances[0] - ec2ICSvc := ec2instanceconnect.New(localSession) + ec2ICSvc := ec2instanceconnect.NewFromConfig(cfg) // no username has been provided, so we try to get it fom the instance tag first if instanceUser == "" { @@ -136,7 +134,7 @@ func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, s ctx.WithField("user", instanceUser).Info("pushing SSH key...") - if _, err := ec2ICSvc.SendSSHPublicKey(&ec2instanceconnect.SendSSHPublicKeyInput{ + if _, err := ec2ICSvc.SendSSHPublicKey(context.TODO(), &ec2instanceconnect.SendSSHPublicKeyInput{ InstanceId: ec2Instance.InstanceId, InstanceOSUser: aws.String(instanceUser), AvailabilityZone: ec2Instance.Placement.AvailabilityZone, @@ -144,9 +142,9 @@ func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, s }); err != nil { return "", "", fmt.Errorf("can't push ssh key: %s", err) } - var address = aws.StringValue(ec2Instance.PrivateIpAddress) - if aws.StringValue(ec2Instance.PublicIpAddress) != "" { - address = aws.StringValue(ec2Instance.PublicIpAddress) + var address = aws.ToString(ec2Instance.PrivateIpAddress) + if aws.ToString(ec2Instance.PublicIpAddress) != "" { + address = aws.ToString(ec2Instance.PublicIpAddress) } return address, instanceUser, nil } diff --git a/lib/util.go b/lib/util.go index c68bf07..d7e64b5 100644 --- a/lib/util.go +++ b/lib/util.go @@ -5,15 +5,15 @@ import ( "sort" "strings" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2/types" ) const bastionCanonicalName = "bastion" var sanitiser = regexp.MustCompile("[\\s-]+") -func getTagValue(tag string, tags []*ec2.Tag, caseInsensitive ...bool) string { +func getTagValue(tag string, tags []types.Tag, caseInsensitive ...bool) string { if len(caseInsensitive) > 0 { if caseInsensitive[0] { tag = strings.ToLower(tag) @@ -21,8 +21,8 @@ func getTagValue(tag string, tags []*ec2.Tag, caseInsensitive ...bool) string { } for _, subTag := range tags { - if aws.StringValue(subTag.Key) == tag { - return aws.StringValue(subTag.Value) + if aws.ToString(subTag.Key) == tag { + return aws.ToString(subTag.Value) } } @@ -30,31 +30,31 @@ func getTagValue(tag string, tags []*ec2.Tag, caseInsensitive ...bool) string { } -func getNameFromTags(tags []*ec2.Tag) string { +func getNameFromTags(tags []types.Tag) string { return strings.ToLower(getTagValue("Name", tags)) } -func getPortFromTags(tags []*ec2.Tag) string { +func getPortFromTags(tags []types.Tag) string { return strings.ToLower(getTagValue("x-aws-ssh-port", tags)) } // GetUserFromTags gets the ec2 username from tags -func GetUserFromTags(tags []*ec2.Tag) string { +func GetUserFromTags(tags []types.Tag) string { return strings.ToLower(getTagValue("x-aws-ssh-user", tags)) } -func isBastionFromTags(tags []*ec2.Tag, checkGlobal bool) bool { +func isBastionFromTags(tags []types.Tag, checkGlobal bool) bool { if len(tags) > 0 { var name string var global bool for _, tag := range tags { - switch aws.StringValue(tag.Key) { + switch aws.ToString(tag.Key) { case "Name": - name = strings.ToLower(aws.StringValue(tag.Value)) + name = strings.ToLower(aws.ToString(tag.Value)) case "Global", "x-aws-ssh-global": { - value := strings.ToLower(aws.StringValue(tag.Value)) + value := strings.ToLower(aws.ToString(tag.Value)) if value == "yes" || value == "true" || value == "1" { global = true } @@ -85,7 +85,7 @@ func (w weights) Len() int { return len(w) } func (w weights) Less(i, j int) bool { return w[i].Weight < w[j].Weight } func (w weights) Swap(i, j int) { w[i], w[j] = w[j], w[i] } -func findBestBastion(instanceName string, bastions []*ec2.Instance) *ec2.Instance { +func findBestBastion(instanceName string, bastions []*types.Instance) *types.Instance { // skip instances with bastionCanonicalName in name if !strings.Contains(instanceName, bastionCanonicalName) && len(bastions) > 0 { if len(bastions) == 1 { @@ -122,11 +122,11 @@ func getInstanceCanonicalName(profile, instanceName, instanceIndex string) strin } func instanceLaunchTimeSorter(i interface{}) interface{} { // sorts by launch time - launched := aws.TimeValue(i.(*ec2.Instance).LaunchTime) + launched := aws.ToTime(i.(*types.Instance).LaunchTime) return launched.Unix() } func instanceNameSorter(i interface{}) interface{} { // sort by instance name - instanceName := getNameFromTags(i.(*ec2.Instance).Tags) + instanceName := getNameFromTags(i.(*types.Instance).Tags) return instanceName }