Skip to content

Commit

Permalink
Dynamically generate ssh config for ec2 connect if instance id is
Browse files Browse the repository at this point in the history
provided
  • Loading branch information
Eugene Dementyev authored and ekini committed Jul 20, 2021
1 parent b07cebe commit 5f3f705
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 23 deletions.
15 changes: 8 additions & 7 deletions cmd/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,19 @@ the first public key from your running ssh agent and then runs ssh command`,
Run: func(cmd *cobra.Command, args []string) {
var sshEntries lib.SSHEntries
var profile string
var instanceId = viper.GetString("instanceid")
var instanceID = viper.GetString("instanceid")
var defaultUser = viper.GetString("user")

profiles := viper.GetStringSlice("profiles")
if len(profiles) > 0 {
profile = profiles[0]
ec2connect.ConnectEC2(
lib.SSHEntries{
lib.SSHEntry{
&lib.SSHEntry{
ProfileConfig: lib.ProfileConfig{Name: profile},
InstanceID: instanceId,
InstanceID: instanceID,
User: defaultUser,
Names: []string{instanceID},
},
},
viper.GetString("ssh-config-path"),
Expand All @@ -55,15 +56,15 @@ the first public key from your running ssh agent and then runs ssh command`,
log.Info("No profile has been provided, switching to the cache mode")
cache := cache.NewYAMLCache(viper.GetString("cache-dir"))

sshEntry, err := cache.Lookup(instanceId)
sshEntry, err := cache.Lookup(instanceID)
if err != nil {
log.WithError(err).Fatalf("can't lookup %s in cache", instanceId)
log.WithError(err).Fatalf("can't lookup %s in cache", instanceID)
}
if sshEntry.User == "" {
sshEntry.User = defaultUser
}

sshEntries = append(sshEntries, sshEntry)
sshEntries = append(sshEntries, &sshEntry)
// ProxyJump is set, which means we need to lookup the bastion host too
if sshEntry.ProxyJump != "" {
bastionEntry, err := cache.Lookup(sshEntry.ProxyJump)
Expand All @@ -74,7 +75,7 @@ the first public key from your running ssh agent and then runs ssh command`,
bastionEntry.User = defaultUser
}
log.WithField("instance_id", bastionEntry.InstanceID).Infof("Got bastion %s", bastionEntry.Names[0])
sshEntries = append(sshEntries, bastionEntry)
sshEntries = append(sshEntries, &bastionEntry)
}
ec2connect.ConnectEC2(sshEntries, viper.GetString("ssh-config-path"), args)
},
Expand Down
1 change: 1 addition & 0 deletions lib/cache/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"aws-ssh/lib"
)

// Cache represents the cache for profiles
type Cache interface {
// Load() loads the cache
Load() ([]lib.ProcessedProfileSummary, error)
Expand Down
43 changes: 28 additions & 15 deletions lib/ec2connect/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ import (
// ConnectEC2 connects to an EC2 instance by pushing your public key onto it first
// using EC2 connect feature and then runs ssh.
func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string) {
// save the dynamic ssh config first
if err := sshEntries.SaveConfig(sshConfigPath); err != nil {
log.WithError(err).Fatal("can't save ssh config for ec2 connect")
}

// get the pub key from the ssh agent first
sshAgent, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))

Expand All @@ -35,8 +30,6 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string)
}
pubkey := keys[0].String()

// then generate ssh config for all instances in sshEntries

// push the pub key to those instances one after each other
// TODO: make it parallel
for _, sshEntry := range sshEntries {
Expand All @@ -46,19 +39,33 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string)
}

log.WithField("instance", instanceName).WithField("user", sshEntry.User).Info("Pushing SSH key...")
err = pushEC2Connect(sshEntry.ProfileConfig.Name, sshEntry.InstanceID, sshEntry.User, pubkey)
instanceIpAddress, err := pushEC2Connect(sshEntry.ProfileConfig.Name, sshEntry.InstanceID, sshEntry.User, pubkey)
if err != nil {
log.WithError(err).Fatal("can't push ssh key to the instance")
}
// if the address is empty we set to the value we got from ec2 connect push
if sshEntry.Address == "" {
sshEntry.Address = instanceIpAddress
}
}

// then generate ssh config for all instances in sshEntries
// save the dynamic ssh config first
if err := sshEntries.SaveConfig(sshConfigPath); err != nil {
log.WithError(err).Fatal("can't save ssh config for ec2 connect")
}

var instanceName = sshEntries[0].InstanceID
if len(sshEntries[0].Names) > 0 {
instanceName = sshEntries[0].Names[0]
}
// connect to the first instance in sshEntry, as the other will be bastion(s)
if len(args) == 0 {
// construct default args
args = []string{
"ssh",
"-tt",
fmt.Sprintf(sshEntries[0].Names[0]),
instanceName,
}
}

Expand All @@ -73,26 +80,28 @@ func ConnectEC2(sshEntries lib.SSHEntries, sshConfigPath string, args []string)
}
}

func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) error {
// pushEC2Connect pushes the ssh key to a given profile and instance ID
// and returns the public (or private if public doesn't exist) address of the EC2 instance
func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) (string, error) {
localSession, err := session.NewSessionWithOptions(session.Options{
Config: aws.Config{},

SharedConfigState: session.SharedConfigEnable,
Profile: profile,
})
if err != nil {
return fmt.Errorf("can't get aws session: %s", err)
return "", fmt.Errorf("can't get aws session: %s", err)
}
ec2Svc := ec2.New(localSession)
ec2Result, err := ec2Svc.DescribeInstances(&ec2.DescribeInstancesInput{
InstanceIds: aws.StringSlice([]string{instanceID}),
})
if err != nil {
return fmt.Errorf("can't get ec2 instance: %s", err)
return "", fmt.Errorf("can't get ec2 instance: %s", err)
}

if len(ec2Result.Reservations) == 0 || len(ec2Result.Reservations[0].Instances) == 0 {
return fmt.Errorf("Couldn't find the instance %s", instanceID)
return "", fmt.Errorf("Couldn't find the instance %s", instanceID)
}

ec2Instance := ec2Result.Reservations[0].Instances[0]
Expand All @@ -104,7 +113,11 @@ func pushEC2Connect(profile, instanceID, instanceUser, pubKey string) error {
AvailabilityZone: ec2Instance.Placement.AvailabilityZone,
SSHPublicKey: aws.String(pubKey),
}); err != nil {
return fmt.Errorf("can't push ssh key: %s", err)
return "", fmt.Errorf("can't push ssh key: %s", err)
}
var address = aws.StringValue(ec2Instance.PrivateIpAddress)
if aws.StringValue(ec2Instance.PublicIpAddress) != "" {
address = aws.StringValue(ec2Instance.PublicIpAddress)
}
return nil
return address, nil
}
4 changes: 3 additions & 1 deletion lib/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ type ProfileConfig struct {
Domain string // domain if set with "aws-ssh-domain" in the config
}

type SSHEntries []SSHEntry
// SSHEntries is a list of SSHEntry with additional function
// to save the config to a file
type SSHEntries []*SSHEntry

// SaveConfig saves the ssh config for the entries
func (e SSHEntries) SaveConfig(path string) error {
Expand Down

0 comments on commit 5f3f705

Please sign in to comment.