Skip to content

Commit

Permalink
Default region + error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
krystian-panek-vmltech committed Jan 26, 2024
1 parent e41d5d6 commit 96d905a
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 19 deletions.
1 change: 0 additions & 1 deletion examples/aws_ssm/aem.tf
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ resource "aem_instance" "single" {
type = "aws-ssm"
settings = {
instance_id = aws_instance.aem_single.id
region = "eu-central-1" // TODO infer from AWS provider config
}
}

Expand Down
3 changes: 3 additions & 0 deletions internal/client/client_manager.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package client

import (
"context"
"fmt"
"github.com/spf13/cast"
)
Expand Down Expand Up @@ -42,6 +43,8 @@ func (c ClientManager) connection(typeName string, settings map[string]string) (
return &AWSSSMConnection{
InstanceID: settings["instance_id"],
Region: settings["region"],

context: context.Background(),
}, nil
}
return nil, fmt.Errorf("unknown AEM client type: %s", typeName)
Expand Down
38 changes: 20 additions & 18 deletions internal/client/connection_aws_ssm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,37 @@ type AWSSSMConnection struct {
Region string
Client *ssm.Client
SessionId *string
Context context.Context

context context.Context
}

func (a *AWSSSMConnection) Info() string {
return fmt.Sprintf("ssm: instance_id='%s', region='%s'", a.InstanceID, a.Region)
}

func (a *AWSSSMConnection) User() string {
out, _ := a.Command([]string{"whoami"})
out, err := a.Command([]string{"whoami"})
if err != nil {
panic(fmt.Sprintf("ssm: cannot determine connected user: %s", err))
}
return strings.TrimSpace(string(out))
}

func (a *AWSSSMConnection) Connect() error {
// Specify the AWS region
a.Context = context.Background()
cfg, err := config.LoadDefaultConfig(a.Context, config.WithRegion(a.Region))
var optFns []func(*config.LoadOptions) error
if a.Region != "" {
optFns = append(optFns, config.WithRegion(a.Region))
}

cfg, err := config.LoadDefaultConfig(a.context, optFns...)
if err != nil {
return err
}

// Create an SSM client
client := ssm.NewFromConfig(cfg)
startSessionInput := &ssm.StartSessionInput{Target: aws.String(a.InstanceID)}

startSessionOutput, err := client.StartSession(a.Context, startSessionInput)
startSessionOutput, err := client.StartSession(a.context, startSessionInput)
if err != nil {
return fmt.Errorf("ssm: error starting session: %v", err)
}
Expand All @@ -56,7 +62,7 @@ func (a *AWSSSMConnection) Disconnect() error {
// Disconnect from the session
terminateSessionInput := &ssm.TerminateSessionInput{SessionId: a.SessionId}

_, err := a.Client.TerminateSession(a.Context, terminateSessionInput)
_, err := a.Client.TerminateSession(a.context, terminateSessionInput)
if err != nil {
return fmt.Errorf("ssm: error terminating session: %v", err)
}
Expand All @@ -65,7 +71,6 @@ func (a *AWSSSMConnection) Disconnect() error {
}

func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) {
// Execute command on the remote instance
command := strings.Join(cmdLine, " ")
runCommandInput := &ssm.SendCommandInput{
DocumentName: aws.String("AWS-RunShellScript"),
Expand All @@ -74,31 +79,28 @@ func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) {
"commands": {command},
},
}

runCommandOutput, err := a.Client.SendCommand(a.Context, runCommandInput)
runOut, err := a.Client.SendCommand(a.context, runCommandInput)
if err != nil {
return nil, fmt.Errorf("ssm: error executing command: %v", err)
}

commandId := runCommandOutput.Command.CommandId

commandInvocationInput := &ssm.GetCommandInvocationInput{
commandId := runOut.Command.CommandId
invocationIn := &ssm.GetCommandInvocationInput{
CommandId: commandId,
InstanceId: aws.String(a.InstanceID),
}

waiter := ssm.NewCommandExecutedWaiter(a.Client)
_, err = waiter.WaitForOutput(a.Context, commandInvocationInput, time.Hour)
_, err = waiter.WaitForOutput(a.context, invocationIn, time.Hour)
if err != nil {
return nil, fmt.Errorf("ssm: error executing command: %v", err)
}

getCommandOutput, err := a.Client.GetCommandInvocation(a.Context, commandInvocationInput)
invocationOut, err := a.Client.GetCommandInvocation(a.context, invocationIn)
if err != nil {
return nil, fmt.Errorf("ssm: error executing command: %v", err)
}

return []byte(*getCommandOutput.StandardOutputContent), nil
return []byte(*invocationOut.StandardOutputContent), nil
}

func (a *AWSSSMConnection) CopyFile(localPath string, remotePath string) error {
Expand Down

0 comments on commit 96d905a

Please sign in to comment.