From 96d905a165f469d92785678ae0e62c0c2f183521 Mon Sep 17 00:00:00 2001 From: Krystian Panek Date: Fri, 26 Jan 2024 08:41:52 +0100 Subject: [PATCH] Default region + error handling --- examples/aws_ssm/aem.tf | 1 - internal/client/client_manager.go | 3 +++ internal/client/connection_aws_ssm.go | 38 ++++++++++++++------------- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/examples/aws_ssm/aem.tf b/examples/aws_ssm/aem.tf index 10f35e3..c4fea01 100644 --- a/examples/aws_ssm/aem.tf +++ b/examples/aws_ssm/aem.tf @@ -5,7 +5,6 @@ resource "aem_instance" "single" { type = "aws-ssm" settings = { instance_id = aws_instance.aem_single.id - region = "eu-central-1" // TODO infer from AWS provider config } } diff --git a/internal/client/client_manager.go b/internal/client/client_manager.go index f76c463..f1f686a 100644 --- a/internal/client/client_manager.go +++ b/internal/client/client_manager.go @@ -1,6 +1,7 @@ package client import ( + "context" "fmt" "github.com/spf13/cast" ) @@ -42,6 +43,8 @@ func (c ClientManager) connection(typeName string, settings map[string]string) ( return &AWSSSMConnection{ InstanceID: settings["instance_id"], Region: settings["region"], + + context: context.Background(), }, nil } return nil, fmt.Errorf("unknown AEM client type: %s", typeName) diff --git a/internal/client/connection_aws_ssm.go b/internal/client/connection_aws_ssm.go index 825410b..7b7b5a1 100644 --- a/internal/client/connection_aws_ssm.go +++ b/internal/client/connection_aws_ssm.go @@ -17,7 +17,8 @@ type AWSSSMConnection struct { Region string Client *ssm.Client SessionId *string - Context context.Context + + context context.Context } func (a *AWSSSMConnection) Info() string { @@ -25,23 +26,28 @@ func (a *AWSSSMConnection) Info() string { } func (a *AWSSSMConnection) User() string { - out, _ := a.Command([]string{"whoami"}) + out, err := a.Command([]string{"whoami"}) + if err != nil { + panic(fmt.Sprintf("ssm: cannot determine connected user: %s", err)) + } return strings.TrimSpace(string(out)) } func (a *AWSSSMConnection) Connect() error { - // Specify the AWS region - a.Context = context.Background() - cfg, err := config.LoadDefaultConfig(a.Context, config.WithRegion(a.Region)) + var optFns []func(*config.LoadOptions) error + if a.Region != "" { + optFns = append(optFns, config.WithRegion(a.Region)) + } + + cfg, err := config.LoadDefaultConfig(a.context, optFns...) if err != nil { return err } - // Create an SSM client client := ssm.NewFromConfig(cfg) startSessionInput := &ssm.StartSessionInput{Target: aws.String(a.InstanceID)} - startSessionOutput, err := client.StartSession(a.Context, startSessionInput) + startSessionOutput, err := client.StartSession(a.context, startSessionInput) if err != nil { return fmt.Errorf("ssm: error starting session: %v", err) } @@ -56,7 +62,7 @@ func (a *AWSSSMConnection) Disconnect() error { // Disconnect from the session terminateSessionInput := &ssm.TerminateSessionInput{SessionId: a.SessionId} - _, err := a.Client.TerminateSession(a.Context, terminateSessionInput) + _, err := a.Client.TerminateSession(a.context, terminateSessionInput) if err != nil { return fmt.Errorf("ssm: error terminating session: %v", err) } @@ -65,7 +71,6 @@ func (a *AWSSSMConnection) Disconnect() error { } func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) { - // Execute command on the remote instance command := strings.Join(cmdLine, " ") runCommandInput := &ssm.SendCommandInput{ DocumentName: aws.String("AWS-RunShellScript"), @@ -74,31 +79,28 @@ func (a *AWSSSMConnection) Command(cmdLine []string) ([]byte, error) { "commands": {command}, }, } - - runCommandOutput, err := a.Client.SendCommand(a.Context, runCommandInput) + runOut, err := a.Client.SendCommand(a.context, runCommandInput) if err != nil { return nil, fmt.Errorf("ssm: error executing command: %v", err) } - commandId := runCommandOutput.Command.CommandId - - commandInvocationInput := &ssm.GetCommandInvocationInput{ + commandId := runOut.Command.CommandId + invocationIn := &ssm.GetCommandInvocationInput{ CommandId: commandId, InstanceId: aws.String(a.InstanceID), } - waiter := ssm.NewCommandExecutedWaiter(a.Client) - _, err = waiter.WaitForOutput(a.Context, commandInvocationInput, time.Hour) + _, err = waiter.WaitForOutput(a.context, invocationIn, time.Hour) if err != nil { return nil, fmt.Errorf("ssm: error executing command: %v", err) } - getCommandOutput, err := a.Client.GetCommandInvocation(a.Context, commandInvocationInput) + invocationOut, err := a.Client.GetCommandInvocation(a.context, invocationIn) if err != nil { return nil, fmt.Errorf("ssm: error executing command: %v", err) } - return []byte(*getCommandOutput.StandardOutputContent), nil + return []byte(*invocationOut.StandardOutputContent), nil } func (a *AWSSSMConnection) CopyFile(localPath string, remotePath string) error {