Skip to content

Commit

Permalink
Update to AWS SDK v2
Browse files Browse the repository at this point in the history
Brings support for AWS SSO.
  • Loading branch information
Eugene Dementyev authored and ekini committed Oct 7, 2021
1 parent aec199a commit 7ae258d
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 80 deletions.
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@ module aws-ssh

require (
github.com/apex/log v1.9.0
github.com/aws/aws-sdk-go v1.38.35
github.com/aws/aws-sdk-go-v2 v1.9.1
github.com/aws/aws-sdk-go-v2/config v1.8.2
github.com/aws/aws-sdk-go-v2/service/ec2 v1.18.0
github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect v1.5.1
github.com/go-ini/ini v1.48.0
github.com/gopherjs/gopherjs v0.0.0-20181103185306-d547d1d9531e // indirect
github.com/hashicorp/go-multierror v1.1.1
Expand Down
24 changes: 23 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,29 @@ github.com/aws/aws-sdk-go v1.20.6/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN
github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.25.37/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo=
github.com/aws/aws-sdk-go v1.36.30/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
github.com/aws/aws-sdk-go v1.38.35 h1:7AlAO0FC+8nFjxiGKEmq0QLpiA8/XFr6eIxgRTwkdTg=
github.com/aws/aws-sdk-go v1.38.35/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro=
github.com/aws/aws-sdk-go-v2 v1.9.1 h1:ZbovGV/qo40nrOJ4q8G33AGICzaPI45FHQWJ9650pF4=
github.com/aws/aws-sdk-go-v2 v1.9.1/go.mod h1:cK/D0BBs0b/oWPIcX/Z/obahJK1TT7IPVjy53i/mX/4=
github.com/aws/aws-sdk-go-v2/config v1.8.2 h1:Dqy4ySXFmulRmZhfynm/5CD4Y6aXiTVhDtXLIuUe/r0=
github.com/aws/aws-sdk-go-v2/config v1.8.2/go.mod h1:r0bkX9NyuCuf28qVcsEMtpAQibT7gA1Q0gzkjvgJdLU=
github.com/aws/aws-sdk-go-v2/credentials v1.4.2 h1:8kVE4Og6wlhVrMGiORQ3p9gRj2exjzhFRB+QzWBUa5Q=
github.com/aws/aws-sdk-go-v2/credentials v1.4.2/go.mod h1:9Sp6u121/f0NnvHyhG7dgoYeUTEFC2vsvJqJ6wXpkaI=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.5.1 h1:Nm+BxqBtT0r+AnD6byGMCGT4Km0QwHBy8mAYptNPXY4=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.5.1/go.mod h1:W1ldHfsgeGlKpJ4xZMKZUI6Wmp6EAstU7PxnhbXWWrI=
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.3 h1:NnXJXUz7oihrSlPKEM0yZ19b+7GQ47MX/LluLlEyE/Y=
github.com/aws/aws-sdk-go-v2/internal/ini v1.2.3/go.mod h1:EES9ToeC3h063zCFDdqWGnARExNdULPaBvARm1FLwxA=
github.com/aws/aws-sdk-go-v2/service/ec2 v1.18.0 h1:5wWtSfYRWgkpKKMW4yJ5llzI9s24Fls7Pv7uw2BiYbk=
github.com/aws/aws-sdk-go-v2/service/ec2 v1.18.0/go.mod h1:d8R2f1hFcknkA3MW4SeExwEua2KpR+dhSrwWlnlwe5Q=
github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect v1.5.1 h1:Nr9llH7oJN3drO0lQgCganTN+3I+AzMTGRPzKo30X3U=
github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect v1.5.1/go.mod h1:iHBeiwp3Xfp7NO//QLJIlk4j5zfH0APBzqpQMSGnCAA=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.3.1 h1:APEjhKZLFlNVLATnA/TJyA+w1r/xd5r5ACWBDZ9aIvc=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.3.1/go.mod h1:Ve+eJOx9UWaT/lMVebnFhDhO49fSLVedHoA82+Rqme0=
github.com/aws/aws-sdk-go-v2/service/sso v1.4.1 h1:RfgQyv3bFT2Js6XokcrNtTjQ6wAVBRpoCgTFsypihHA=
github.com/aws/aws-sdk-go-v2/service/sso v1.4.1/go.mod h1:ycPdbJZlM0BLhuBnd80WX9PucWPG88qps/2jl9HugXs=
github.com/aws/aws-sdk-go-v2/service/sts v1.7.1 h1:7ce9ugapSgBapwLhg7AJTqKW5U92VRX3vX65k2tsB+g=
github.com/aws/aws-sdk-go-v2/service/sts v1.7.1/go.mod h1:r1i8QwKPzwByXqZb3POQfBs7jozrdnHz8PVbsvyx73w=
github.com/aws/smithy-go v1.8.0 h1:AEwwwXQZtUwP5Mz506FeXXrKBe0jA8gVM+1gEcSRooc=
github.com/aws/smithy-go v1.8.0/go.mod h1:SObp3lf9smib00L/v3U2eAKG8FyQ7iLrJnQiAmR5n+E=
github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8=
Expand Down Expand Up @@ -1170,6 +1191,7 @@ golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
google.golang.org/api v0.5.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=
Expand Down
88 changes: 42 additions & 46 deletions lib/aws.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package lib

import (
"context"
"fmt"
"sort"

"github.com/apex/log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2/types"
multierror "github.com/hashicorp/go-multierror"
linq "gopkg.in/ahmetb/go-linq.v3"
)
Expand All @@ -17,7 +19,7 @@ import (
type profileSummary struct {
ProfileConfig

Instances []*ec2.Instance
Instances []types.Instance
}

// ProcessedProfileSummary represents profile summary
Expand All @@ -28,21 +30,6 @@ type ProcessedProfileSummary struct {
SSHEntries []SSHEntry
}

func makeSession(profile string) (*session.Session, error) {
log.Debugf("Creating session for %s", profile)
// create AWS session
localSession, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{},

SharedConfigState: session.SharedConfigEnable,
Profile: profile,
})
if err != nil {
return nil, fmt.Errorf("can't get aws session")
}
return localSession, nil
}

// TraverseProfiles goes through all profiles and returns a list of ProcessedProfileSummary
func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]ProcessedProfileSummary, error) {
log.Debugf("Traversing through %d profiles", len(profiles))
Expand Down Expand Up @@ -86,49 +73,49 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process
linq.From(summary.Instances).OrderBy(instanceNameSorter). // sort by name first
ThenBy(instanceLaunchTimeSorter). // then by launch time
GroupBy(func(i interface{}) interface{} { // and then group by vpc
vpcID := i.(*ec2.Instance).VpcId
return aws.StringValue(vpcID)
vpcID := i.(*types.Instance).VpcId
return aws.ToString(vpcID)
}, func(i interface{}) interface{} {
return i.(*ec2.Instance)
return i.(*types.Instance)
}).ToSlice(&vpcInstances)

var commonBastions []*ec2.Instance
var commonBastions []*types.Instance
linq.From(summary.Instances).OrderBy(instanceNameSorter). // sort by name first
ThenBy(instanceLaunchTimeSorter). // then by launch time
Where(
func(f interface{}) bool {
return isBastionFromTags(f.(*ec2.Instance).Tags, true) // check for global tag as well
return isBastionFromTags(f.(*types.Instance).Tags, true) // check for global tag as well
},
).ToSlice(&commonBastions)

ctx.Debugf("Found %d common (global) bastions", len(commonBastions))

for _, vpcGroup := range vpcInstances { // take the instances grouped by vpc and iterate
var vpcBastions []*ec2.Instance
var vpcBastions []*types.Instance
linq.From(vpcGroup.Group).Where(
func(f interface{}) bool {
return isBastionFromTags(f.(*ec2.Instance).Tags, false) // "false" means don't check for global tag
return isBastionFromTags(f.(*types.Instance).Tags, false) // "false" means don't check for global tag
},
).ToSlice(&vpcBastions)

ctx.WithField("vpc", vpcGroup.Key).Debugf("Found %d bastions", len(vpcBastions))

var nameInstances []linq.Group
linq.From(vpcGroup.Group).GroupBy(func(i interface{}) interface{} { // now group them by name
instanceName := getNameFromTags(i.(*ec2.Instance).Tags)
instanceName := getNameFromTags(i.(*types.Instance).Tags)
return instanceName
}, func(i interface{}) interface{} {
return i.(*ec2.Instance)
return i.(*types.Instance)
}).ToSlice(&nameInstances)

// now we have instances, grouped by vpc and name
for _, nameGroup := range nameInstances {
instanceName := nameGroup.Key.(string)

for n, instance := range nameGroup.Group {
instance := instance.(*ec2.Instance)
instance := instance.(*types.Instance)
var entry = SSHEntry{
InstanceID: aws.StringValue(instance.InstanceId),
InstanceID: aws.ToString(instance.InstanceId),
ProfileConfig: ProfileConfig{
Name: summary.Name,
Region: summary.Region,
Expand All @@ -143,14 +130,14 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process
if bastion == nil { // then try common ones
bastion = findBestBastion(instanceName, commonBastions)
}
entry.Address = aws.StringValue(instance.PrivateIpAddress) // get the private address first as we always have one
if bastion != nil { // get private address and add proxyhost, which is the bastion ip
entry.Address = aws.ToString(instance.PrivateIpAddress) // get the private address first as we always have one
if bastion != nil { // get private address and add proxyhost, which is the bastion ip
// refer to the bastion by its instance ID
// which we should have a record for
entry.ProxyJump = aws.StringValue(bastion.InstanceId)
entry.ProxyJump = aws.ToString(bastion.InstanceId)
} else { // get public IP if we have one
if publicIP := aws.StringValue(instance.PublicIpAddress); publicIP != "" {
entry.Address = aws.StringValue(instance.PublicIpAddress)
if publicIP := aws.ToString(instance.PublicIpAddress); publicIP != "" {
entry.Address = aws.ToString(instance.PublicIpAddress)
}
}
var instanceIndex string
Expand Down Expand Up @@ -187,7 +174,9 @@ func TraverseProfiles(profiles []ProfileConfig, noProfilePrefix bool) ([]Process

// DescribeProfile describes the specified profile
func DescribeProfile(profile ProfileConfig, sum chan profileSummary, errChan chan error) {
awsSession, err := makeSession(profile.Name)
cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithSharedConfigProfile(profile.Name))

if err != nil {
errChan <- fmt.Errorf("Couldn't create session for '%s': %s", profile.Name, err)
return
Expand All @@ -196,29 +185,36 @@ func DescribeProfile(profile ProfileConfig, sum chan profileSummary, errChan cha
profileSummary := profileSummary{
ProfileConfig: ProfileConfig{
Name: profile.Name,
Region: aws.StringValue(awsSession.Config.Region),
Region: cfg.Region,
Domain: profile.Domain,
},
}

svc := ec2.New(awsSession)
svc := ec2.NewFromConfig(cfg)
input := &ec2.DescribeInstancesInput{
Filters: []*ec2.Filter{
Filters: []types.Filter{
{
Name: aws.String("instance-state-name"),
Values: aws.StringSlice([]string{ec2.InstanceStateNameRunning}),
Values: []string{string(types.InstanceStateNameRunning)},
},
},
}

err = svc.DescribeInstancesPages(input, func(result *ec2.DescribeInstancesOutput, lastPage bool) bool {
for _, reservation := range result.Reservations {
for _, instance := range reservation.Instances {
profileSummary.Instances = append(profileSummary.Instances, instance)
paginator := ec2.NewDescribeInstancesPaginator(svc, input)
for paginator.HasMorePages() {
result, subErr := paginator.NextPage(context.TODO())
if subErr != nil {
err = subErr
break
} else {
for _, reservation := range result.Reservations {
for _, instance := range reservation.Instances {
profileSummary.Instances = append(profileSummary.Instances, instance)
}
}
}
return false
})
}

if err != nil {
errChan <- fmt.Errorf("Can't get full information for '%s': %s", profile, err)
} else {
Expand Down
32 changes: 15 additions & 17 deletions lib/ec2connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package ec2connect

import (
"aws-ssh/lib"
"context"

"fmt"
"net"
Expand All @@ -11,10 +12,10 @@ import (
"syscall"

"github.com/apex/log"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2instanceconnect"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ec2"
"github.com/aws/aws-sdk-go-v2/service/ec2instanceconnect"
"golang.org/x/crypto/ssh/agent"
)

Expand Down Expand Up @@ -97,18 +98,15 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string)
// and returns the public (or private if public doesn't exist) address of the EC2 instance
func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, string, error) {
ctx := log.WithField("instance_id", instanceID)
localSession, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{},
cfg, err := config.LoadDefaultConfig(context.TODO(),
config.WithSharedConfigProfile(profile))

SharedConfigState: session.SharedConfigEnable,
Profile: profile,
})
if err != nil {
return "", "", fmt.Errorf("can't get aws session: %s", err)
}
ec2Svc := ec2.New(localSession)
ec2Result, err := ec2Svc.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: aws.StringSlice([]string{instanceID}),
ec2Svc := ec2.NewFromConfig(cfg)
ec2Result, err := ec2Svc.DescribeInstances(context.TODO(), &ec2.DescribeInstancesInput{
InstanceIds: []string{instanceID},
})
if err != nil {
return "", "", fmt.Errorf("can't get ec2 instance: %s", err)
Expand All @@ -119,7 +117,7 @@ func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, s
}

ec2Instance := ec2Result.Reservations[0].Instances[0]
ec2ICSvc := ec2instanceconnect.New(localSession)
ec2ICSvc := ec2instanceconnect.NewFromConfig(cfg)

// no username has been provided, so we try to get it fom the instance tag first
if instanceUser == "" {
Expand All @@ -136,17 +134,17 @@ func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, s

ctx.WithField("user", instanceUser).Info("pushing SSH key...")

if _, err := ec2ICSvc.SendSSHPublicKey(&ec2instanceconnect.SendSSHPublicKeyInput{
if _, err := ec2ICSvc.SendSSHPublicKey(context.TODO(), &ec2instanceconnect.SendSSHPublicKeyInput{
InstanceId: ec2Instance.InstanceId,
InstanceOSUser: aws.String(instanceUser),
AvailabilityZone: ec2Instance.Placement.AvailabilityZone,
SSHPublicKey: aws.String(pubKey),
}); err != nil {
return "", "", fmt.Errorf("can't push ssh key: %s", err)
}
var address = aws.StringValue(ec2Instance.PrivateIpAddress)
if aws.StringValue(ec2Instance.PublicIpAddress) != "" {
address = aws.StringValue(ec2Instance.PublicIpAddress)
var address = aws.ToString(ec2Instance.PrivateIpAddress)
if aws.ToString(ec2Instance.PublicIpAddress) != "" {
address = aws.ToString(ec2Instance.PublicIpAddress)
}
return address, instanceUser, nil
}
Loading

0 comments on commit 7ae258d

Please sign in to comment.