diff --git a/pkg/aws/aws_client/instance.go b/pkg/aws/aws_client/instance.go index f6d2fab..b6462e2 100644 --- a/pkg/aws/aws_client/instance.go +++ b/pkg/aws/aws_client/instance.go @@ -15,7 +15,8 @@ import ( "github.com/openshift-online/ocm-common/pkg/log" ) -func (client *AWSClient) LaunchInstance(subnetID string, imageID string, count int, instanceType string, keyName string, securityGroupIds []string, wait bool) (*ec2.RunInstancesOutput, error) { +func (client *AWSClient) LaunchInstance(subnetID string, imageID string, count int, instanceType string, keyName string, + securityGroupIds []string, wait bool, userDate ...string) (*ec2.RunInstancesOutput, error) { input := &ec2.RunInstancesInput{ ImageId: aws.String(imageID), MinCount: aws.Int32(int32(count)), @@ -25,6 +26,10 @@ func (client *AWSClient) LaunchInstance(subnetID string, imageID string, count i SecurityGroupIds: securityGroupIds, SubnetId: &subnetID, } + if len(userDate) > 0 { + input.UserData = &userDate[0] + } + output, err := client.Ec2Client.RunInstances(context.TODO(), input) if wait && err == nil { instanceIDs := []string{} diff --git a/pkg/aws/consts/consts.go b/pkg/aws/consts/consts.go index e1f0b7d..a3a4227 100644 --- a/pkg/aws/consts/consts.go +++ b/pkg/aws/consts/consts.go @@ -41,9 +41,11 @@ const ( TCPProtocol = "tcp" UDPProtocol = "udp" + BastionSecurityGroupName = "bastion-sg" ProxySecurityGroupName = "proxy-sg" AdditionalSecurityGroupName = "ocm-additional-sg" ProxySecurityGroupDescription = "security group for proxy" + BastionSecurityGroupDescription = "security group for bastion" DefaultAdditionalSecurityGroupDescription = "This security group is created for OCM testing" QEFlagKey = "ocm_ci_flag" diff --git a/pkg/test/vpc_client/bastion.go b/pkg/test/vpc_client/bastion.go index a30dc9b..6b04d4f 100644 --- a/pkg/test/vpc_client/bastion.go +++ b/pkg/test/vpc_client/bastion.go @@ -1,20 +1,22 @@ package vpc_client import ( + "encoding/base64" + "errors" "fmt" - "time" - "github.com/aws/aws-sdk-go-v2/service/ec2/types" + "net" CON "github.com/openshift-online/ocm-common/pkg/aws/consts" "github.com/openshift-online/ocm-common/pkg/log" ) // LaunchBastion will launch a bastion instance on the indicated zone. -// If set imageID to empty, it will find the bastion image in the bastionImageMap map -func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, error) { +// If set imageID to empty, it will find the bastion image using filter with specific name. +func (vpc *VPC) LaunchBastion(imageID string, zone string, userData string) (*types.Instance, error) { var inst *types.Instance if imageID == "" { + var err error imageID, err = vpc.FindProxyLaunchImage() if err != nil { @@ -22,12 +24,16 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, err return nil, err } } + if userData == "" { + log.LogError("Userdata can not be empty, pleas provide the correct userdata") + return nil, errors.New("userData should not be empty") + } pubSubnet, err := vpc.PreparePublicSubnet(zone) if err != nil { - log.LogInfo("Error preparing a subnet in current zone %s with image ID %s: %s", zone, imageID, err) + log.LogError("Error preparing a subnet in current zone %s with image ID %s: %s", zone, imageID, err) return nil, err } - SGID, err := vpc.CreateAndAuthorizeDefaultSecurityGroupForProxy() + SGID, err := vpc.CreateAndAuthorizeDefaultSecurityGroupForProxy(3128) if err != nil { log.LogError("Prepare SG failed for the bastion preparation %s", err) return inst, err @@ -38,7 +44,8 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, err log.LogError("Create key pair failed %s", err) return inst, err } - instOut, err := vpc.AWSClient.LaunchInstance(pubSubnet.ID, imageID, 1, "t3.medium", *key.KeyName, []string{SGID}, true) + instOut, err := vpc.AWSClient.LaunchInstance(pubSubnet.ID, imageID, 1, "t3.medium", *key.KeyName, + []string{SGID}, true, userData) if err != nil { log.LogError("Launch bastion instance failed %s", err) @@ -63,11 +70,10 @@ func (vpc *VPC) LaunchBastion(imageID string, zone string) (*types.Instance, err log.LogInfo("Prepare EIP successfully for the bastion preparation. Launch with IP: %s", publicIP) inst = &instOut.Instances[0] inst.PublicIpAddress = &publicIP - time.Sleep(2 * time.Minute) return inst, nil } -func (vpc *VPC) PrepareBastion(zone string) (*types.Instance, error) { +func (vpc *VPC) PrepareBastionProxy(zone string, cidrBlock string) (*types.Instance, error) { filters := []map[string][]string{ { "vpc-id": { @@ -87,9 +93,43 @@ func (vpc *VPC) PrepareBastion(zone string) (*types.Instance, error) { } if len(insts) == 0 { log.LogInfo("Didn't found an existing bastion, going to launch one") - return vpc.LaunchBastion("", zone) + if cidrBlock == "" { + cidrBlock = CON.RouteDestinationCidrBlock + } + _, _, err = net.ParseCIDR(cidrBlock) + if err != nil { + log.LogError("CIDR IP address format is invalid") + return nil, err + } + + userData := fmt.Sprintf(`#!/bin/bash + yum update -y + yum install -y squid + cd /etc/squid/ + sudo mv ./squid.conf ./squid.conf.bak + sudo touch squid.conf + echo http_port 3128 >> /etc/squid/squid.conf + echo acl allowed_ips src %s >> /etc/squid/squid.conf + echo http_access allow allowed_ips >> /etc/squid/squid.conf + echo http_access deny all >> /etc/squid/squid.conf + systemctl start squid + systemctl enable squid`, cidrBlock) + + encodeUserData := base64.StdEncoding.EncodeToString([]byte(userData)) + return vpc.LaunchBastion("", zone, encodeUserData) } log.LogInfo("Found existing bastion: %s", *insts[0].InstanceId) return &insts[0], nil } + +func (vpc *VPC) DestroyBastionProxy(instance types.Instance) error { + var instanceIDs []string + instanceIDs = append(instanceIDs, *instance.InstanceId) + err := vpc.AWSClient.TerminateInstances(instanceIDs, true, 10) + if err != nil { + log.LogError("Terminate instance failed") + return err + } + return nil +} diff --git a/pkg/test/vpc_client/security_group.go b/pkg/test/vpc_client/security_group.go index cb334cc..280b8d3 100644 --- a/pkg/test/vpc_client/security_group.go +++ b/pkg/test/vpc_client/security_group.go @@ -36,10 +36,21 @@ func (vpc *VPC) DeleteVPCSecurityGroups(customizedOnly bool) error { } // CreateAndAuthorizeDefaultSecurityGroupForProxy can prepare a security group for the proxy launch -func (vpc *VPC) CreateAndAuthorizeDefaultSecurityGroupForProxy() (string, error) { +func (vpc *VPC) CreateAndAuthorizeDefaultSecurityGroupForProxy(ports ...int32) (string, error) { var groupID string var err error - sgIDs, err := vpc.CreateAdditionalSecurityGroups(1, con.ProxySecurityGroupName, con.ProxySecurityGroupDescription) + var sgIDs []string + var securityGroupName string + var securityGroupDescription string + if len(ports) > 0 { + securityGroupName = con.BastionSecurityGroupName + securityGroupDescription = con.BastionSecurityGroupDescription + } else { + securityGroupName = con.ProxySecurityGroupName + securityGroupDescription = con.ProxySecurityGroupDescription + } + sgIDs, err = vpc.CreateAdditionalSecurityGroups(1, securityGroupName, securityGroupDescription, ports...) + if err != nil { log.LogError("Security group prepare for proxy failed") } else { @@ -52,7 +63,7 @@ func (vpc *VPC) CreateAndAuthorizeDefaultSecurityGroupForProxy() (string, error) // CreateAdditionalSecurityGroups can prepare additional security groups // description can be empty which will be set to default value // namePrefix is required, otherwise if there is same security group existing the creation will fail -func (vpc *VPC) CreateAdditionalSecurityGroups(count int, namePrefix string, description string) ([]string, error) { +func (vpc *VPC) CreateAdditionalSecurityGroups(count int, namePrefix string, description string, ports ...int32) ([]string, error) { preparedSGs := []string{} createdsgNum := 0 if description == "" { @@ -65,15 +76,21 @@ func (vpc *VPC) CreateAdditionalSecurityGroups(count int, namePrefix string, des panic(err) } groupID := *sg.GroupId - cidrPortsMap := map[string]int32{ - vpc.CIDRValue: 8080, - con.RouteDestinationCidrBlock: 22, + cidrPortsMap := make(map[string][]int32) + cidrPortsMap[con.RouteDestinationCidrBlock] = append(cidrPortsMap[con.RouteDestinationCidrBlock], 22) + if len(ports) > 0 { + cidrPortsMap[con.RouteDestinationCidrBlock] = append(cidrPortsMap[con.RouteDestinationCidrBlock], ports...) + } else { + cidrPortsMap[vpc.CIDRValue] = append(cidrPortsMap[vpc.CIDRValue], 8080) } for cidr, port := range cidrPortsMap { - _, err = vpc.AWSClient.AuthorizeSecurityGroupIngress(groupID, cidr, con.TCPProtocol, port, port) - if err != nil { - return preparedSGs, err + for _, v := range port { + _, err = vpc.AWSClient.AuthorizeSecurityGroupIngress(groupID, cidr, con.TCPProtocol, v, v) + if err != nil { + return preparedSGs, err + } } + } preparedSGs = append(preparedSGs, *sg.GroupId)