diff --git a/pkg/aws/aws_client/instance.go b/pkg/aws/aws_client/instance.go index f6d2fab..df3e29a 100644 --- a/pkg/aws/aws_client/instance.go +++ b/pkg/aws/aws_client/instance.go @@ -15,16 +15,32 @@ import ( "github.com/openshift-online/ocm-common/pkg/log" ) -func (client *AWSClient) LaunchInstance(subnetID string, imageID string, count int, instanceType string, keyName string, securityGroupIds []string, wait bool) (*ec2.RunInstancesOutput, error) { - input := &ec2.RunInstancesInput{ - ImageId: aws.String(imageID), - MinCount: aws.Int32(int32(count)), - MaxCount: aws.Int32(int32(count)), - InstanceType: types.InstanceType(instanceType), - KeyName: aws.String(keyName), - SecurityGroupIds: securityGroupIds, - SubnetId: &subnetID, +func (client *AWSClient) LaunchInstance(subnetID string, imageID string, count int, instanceType string, keyName string, + securityGroupIds []string, wait bool, userDate ...string) (*ec2.RunInstancesOutput, error) { + var input *ec2.RunInstancesInput + if len(userDate) > 0 { + input = &ec2.RunInstancesInput{ + ImageId: aws.String(imageID), + MinCount: aws.Int32(int32(count)), + MaxCount: aws.Int32(int32(count)), + InstanceType: types.InstanceType(instanceType), + KeyName: aws.String(keyName), + SecurityGroupIds: securityGroupIds, + SubnetId: &subnetID, + UserData: &userDate[0], + } + } else { + input = &ec2.RunInstancesInput{ + ImageId: aws.String(imageID), + MinCount: aws.Int32(int32(count)), + MaxCount: aws.Int32(int32(count)), + InstanceType: types.InstanceType(instanceType), + KeyName: aws.String(keyName), + SecurityGroupIds: securityGroupIds, + SubnetId: &subnetID, + } } + output, err := client.Ec2Client.RunInstances(context.TODO(), input) if wait && err == nil { instanceIDs := []string{} diff --git a/pkg/test/vpc_client/bastion.go b/pkg/test/vpc_client/bastion.go index a30dc9b..7edcdeb 100644 --- a/pkg/test/vpc_client/bastion.go +++ b/pkg/test/vpc_client/bastion.go @@ -1,20 +1,21 @@ package vpc_client import ( + "encoding/base64" "fmt" - "time" - "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "net" CON "github.com/openshift-online/ocm-common/pkg/aws/consts" "github.com/openshift-online/ocm-common/pkg/log" ) // LaunchBastion will launch a bastion instance on the indicated zone. -// If set imageID to empty, it will find the bastion image in the bastionImageMap map -func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, error) { +// If set imageID to empty, it will find the bastion image using filter with specific name. +func (vpc *VPC) LaunchBastion(imageID string, zone string, userData ...string) (*types.Instance, error) { var inst *types.Instance if imageID == "" { + var err error imageID, err = vpc.FindProxyLaunchImage() if err != nil { @@ -24,10 +25,10 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, err } pubSubnet, err := vpc.PreparePublicSubnet(zone) if err != nil { - log.LogInfo("Error preparing a subnet in current zone %s with image ID %s: %s", zone, imageID, err) + log.LogError("Error preparing a subnet in current zone %s with image ID %s: %s", zone, imageID, err) return nil, err } - SGID, err := vpc.CreateAndAuthorizeDefaultSecurityGroupForProxy() + SGID, err := vpc.CreateAndAuthorizeDefaultSecurityGroupForProxy(3128) if err != nil { log.LogError("Prepare SG failed for the bastion preparation %s", err) return inst, err @@ -38,7 +39,8 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, err log.LogError("Create key pair failed %s", err) return inst, err } - instOut, err := vpc.AWSClient.LaunchInstance(pubSubnet.ID, imageID, 1, "t3.medium", *key.KeyName, []string{SGID}, true) + instOut, err := vpc.AWSClient.LaunchInstance(pubSubnet.ID, imageID, 1, "t3.medium", *key.KeyName, + []string{SGID}, true, userData...) if err != nil { log.LogError("Launch bastion instance failed %s", err) @@ -63,11 +65,10 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, err log.LogInfo("Prepare EIP successfully for the bastion preparation. Launch with IP: %s", publicIP) inst = &instOut.Instances[0] inst.PublicIpAddress = &publicIP - time.Sleep(2 * time.Minute) return inst, nil } -func (vpc *VPC) PrepareBastion(zone string) (*types.Instance, error) { +func (vpc *VPC) PrepareBastionProxy(zone string, cidrIp string) (*types.Instance, error) { filters := []map[string][]string{ { "vpc-id": { @@ -87,9 +88,64 @@ func (vpc *VPC) PrepareBastion(zone string) (*types.Instance, error) { } if len(insts) == 0 { log.LogInfo("Didn't found an existing bastion, going to launch one") - return vpc.LaunchBastion("", zone) + if cidrIp == "" { + cidrIp = CON.RouteDestinationCidrBlock + } + _, _, err = net.ParseCIDR(cidrIp) + if err != nil { + log.LogError("CIDR IP address format is invalid") + return nil, err + } + userData := fmt.Sprintf("#!/bin/bash\n\t\t"+ + "yum update -y\n\t\t"+ + "yum install -y squid\n\t\t"+ + "cd /etc/squid/\n\t\t"+ + "sudo mv ./squid.conf ./squid.conf.bak\n\t\t"+ + "sudo touch squid.conf\n\t\t"+ + "echo \"http_port 3128\" >> /etc/squid/squid.conf\n\t\t"+ + "echo \"acl allowed_ips src %s\" >> /etc/squid/squid.conf\n\t\t"+ + "echo \"http_access allow allowed_ips\" >> /etc/squid/squid.conf\n\t\t"+ + "echo \"http_access deny all\" >> /etc/squid/squid.conf\n\t\t"+ + "systemctl start squid\n\t\t"+ + "systemctl enable squid", cidrIp) + + encodeUserData := base64.StdEncoding.EncodeToString([]byte(userData)) + regionZone := fmt.Sprintf("%s%s", vpc.Region, zone) + return vpc.LaunchBastion("", regionZone, encodeUserData) } log.LogInfo("Found existing bastion: %s", *insts[0].InstanceId) return &insts[0], nil } + +func (vpc *VPC) DestroyBastionProxy(instance types.Instance) error { + var instanceIDs []string + instanceIDs = append(instanceIDs, *instance.InstanceId) + err := vpc.AWSClient.TerminateInstances(instanceIDs, true, 10) + if err != nil { + log.LogError("Terminate instance failed") + return err + } + + var keyNames []string + keyNames = append(keyNames, *instance.KeyName) + err = vpc.DeleteKeyPair(keyNames) + if err != nil { + log.LogError("Delete key pair failed") + return err + } + + err = vpc.DeleteVPCSecurityGroups(true) + if err != nil { + log.LogError("Delete VPC security group failed") + return err + } + + _, err = vpc.AWSClient.DeleteSubnet(*instance.SubnetId) + if err != nil { + log.LogError("Delete VPC public subnet failed") + return err + } + + return nil +} diff --git a/pkg/test/vpc_client/security_group.go b/pkg/test/vpc_client/security_group.go index cb334cc..e058cfc 100644 --- a/pkg/test/vpc_client/security_group.go +++ b/pkg/test/vpc_client/security_group.go @@ -36,10 +36,13 @@ func (vpc *VPC) DeleteVPCSecurityGroups(customizedOnly bool) error { } // CreateAndAuthorizeDefaultSecurityGroupForProxy can prepare a security group for the proxy launch -func (vpc *VPC) CreateAndAuthorizeDefaultSecurityGroupForProxy() (string, error) { +func (vpc *VPC) CreateAndAuthorizeDefaultSecurityGroupForProxy(ports ...int32) (string, error) { var groupID string var err error - sgIDs, err := vpc.CreateAdditionalSecurityGroups(1, con.ProxySecurityGroupName, con.ProxySecurityGroupDescription) + var sgIDs []string + sgIDs, err = vpc.CreateAdditionalSecurityGroups(1, con.ProxySecurityGroupName, + con.ProxySecurityGroupDescription, ports...) + if err != nil { log.LogError("Security group prepare for proxy failed") } else { @@ -52,7 +55,7 @@ func (vpc *VPC) CreateAndAuthorizeDefaultSecurityGroupForProxy() (string, error) // CreateAdditionalSecurityGroups can prepare additional security groups // description can be empty which will be set to default value // namePrefix is required, otherwise if there is same security group existing the creation will fail -func (vpc *VPC) CreateAdditionalSecurityGroups(count int, namePrefix string, description string) ([]string, error) { +func (vpc *VPC) CreateAdditionalSecurityGroups(count int, namePrefix string, description string, ports ...int32) ([]string, error) { preparedSGs := []string{} createdsgNum := 0 if description == "" { @@ -65,15 +68,21 @@ func (vpc *VPC) CreateAdditionalSecurityGroups(count int, namePrefix string, des panic(err) } groupID := *sg.GroupId - cidrPortsMap := map[string]int32{ - vpc.CIDRValue: 8080, - con.RouteDestinationCidrBlock: 22, + cidrPortsMap := make(map[string][]int32) + cidrPortsMap[con.RouteDestinationCidrBlock] = append(cidrPortsMap[con.RouteDestinationCidrBlock], 22) + if len(ports) > 0 { + cidrPortsMap[con.RouteDestinationCidrBlock] = append(cidrPortsMap[con.RouteDestinationCidrBlock], ports...) + } else { + cidrPortsMap[vpc.CIDRValue] = append(cidrPortsMap[vpc.CIDRValue], 8080) } for cidr, port := range cidrPortsMap { - _, err = vpc.AWSClient.AuthorizeSecurityGroupIngress(groupID, cidr, con.TCPProtocol, port, port) - if err != nil { - return preparedSGs, err + for _, v := range port { + _, err = vpc.AWSClient.AuthorizeSecurityGroupIngress(groupID, cidr, con.TCPProtocol, v, v) + if err != nil { + return preparedSGs, err + } } + } preparedSGs = append(preparedSGs, *sg.GroupId)