Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addng support for MFA #224

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 159 additions & 39 deletions beeswithmachineguns/bees.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,18 @@

import boto.ec2
import boto.exception

from boto.sts import STSConnection
from boto.ec2.connection import EC2Connection

import paramiko
import json
from collections import defaultdict
import time


STATE_FILENAME = os.path.expanduser('~/.bees')
CONFIG_FILENAME = os.path.expanduser('~/.bees_config')

# Utilities

Expand Down Expand Up @@ -117,10 +122,13 @@ def _get_security_group_id(connection, security_group_name, subnet):
print('The bees need a security group to run under. Need to open a port from where you are to the target subnet.')
return

security_groups = connection.get_all_security_groups(filters={'group-name': [security_group_name]})
print('in this function')
ec2_connection = connection

security_groups = ec2_connection.get_all_security_groups(filters={'group-name': [security_group_name]})

if not security_groups:
security_groups = connection.get_all_security_groups(filters={'group-id': [security_group_name]})
security_groups = ec2_connection.get_all_security_groups(filters={'group-id': [security_group_name]})
if not security_groups:
print('The bees need a security group to run under. The one specified was not found.')
return
Expand All @@ -134,11 +142,57 @@ def up(count, group, zone, image_id, instance_type, username, key_name, subnet,
Startup the load testing server.
"""

try:

file_exists = os.path.isfile(CONFIG_FILENAME)
if file_exists:
file = open(CONFIG_FILENAME, "r")
lines = file.readlines()
remembered_mfa_serial=lines[0].replace("\n","")
remembered_region=lines[1]
mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial
else:
mfa_serial = raw_input("Enter the MFA serial (for example arn:aws:iam::1234567891011:mfa/myusername): ")
mfa_TOTP = raw_input("Enter the MFA code: ")

sts_connection = STSConnection()

tempCredentials = sts_connection.get_session_token(
duration=3600,
mfa_serial_number=mfa_serial,
mfa_token=mfa_TOTP
)

region = boto.ec2.get_region(_get_region(zone))
ec2_connection = EC2Connection(
region=region,
aws_access_key_id=tempCredentials.access_key,
aws_secret_access_key=tempCredentials.secret_key,
security_token=tempCredentials.session_token
)

file = open(CONFIG_FILENAME, "w")
file.write(mfa_serial + "\n")
file.write(_get_region(zone))
file.close()

except boto.exception.NoAuthHandlerFound as e:
print("Authenciation config error, perhaps you do not have a ~/.boto file with correct permissions?")
print(e.message)
return e
except Exception as e:
print("Unknown error occured:")
print(e.message)
return e

if ec2_connection == None:
raise Exception("Invalid zone specified? Unable to connect to region using zone name")

existing_username, existing_key_name, existing_zone, instance_ids = _read_server_list(zone)

count = int(count)
if existing_username == username and existing_key_name == key_name and existing_zone == zone:
ec2_connection = boto.ec2.connect_to_region(_get_region(zone))

existing_reservations = ec2_connection.get_all_instances(instance_ids=instance_ids)
existing_instances = [instance for reservation in existing_reservations for instance in reservation.instances if instance.state == 'running']
# User, key and zone match existing values and instance ids are found on state file
Expand All @@ -165,41 +219,38 @@ def up(count, group, zone, image_id, instance_type, username, key_name, subnet,
print('Warning. No key file found for %s. You will need to add this key to your SSH agent to connect.' % pem_path)

print('Connecting to the hive.')

try:
ec2_connection = boto.ec2.connect_to_region(_get_region(zone))
except boto.exception.NoAuthHandlerFound as e:
print("Authenciation config error, perhaps you do not have a ~/.boto file with correct permissions?")
print(e.message)
return e
except Exception as e:
print("Unknown error occured:")
print(e.message)
return e

if ec2_connection == None:
raise Exception("Invalid zone specified? Unable to connect to region using zone name")

groupId = group if subnet is None else _get_security_group_id(ec2_connection, group, subnet)
print("GroupId found: %s" % groupId)

placement = None if 'gov' in zone else zone
print("Placement: %s" % placement)

if bid:
print('Attempting to call up %i spot bees, this can take a while...' % count)

spot_requests = ec2_connection.request_spot_instances(
image_id=image_id,
price=bid,
count=count,
key_name=key_name,
security_group_ids=[groupId],
instance_type=instance_type,
placement=placement,
subnet_id=subnet)

# it can take a few seconds before the spot requests are fully processed
print('Attempting to call up %i spot bees, this can take a while...' % count)

if "sg-" not in groupId:
spot_requests = ec2_connection.request_spot_instances(
image_id=image_id,
price=bid,
count=count,
key_name=key_name,
security_groups=[groupId],
instance_type=instance_type,
placement=placement,
subnet_id=subnet)

else:
spot_requests = ec2_connection.request_spot_instances(
image_id=image_id,
price=bid,
count=count,
key_name=key_name,
security_group_ids=[groupId],
instance_type=instance_type,
placement=placement,
subnet_id=subnet)

# it can take a few seconds before the spot requests are fully processed
time.sleep(5)

instances = _wait_for_spot_request_fulfillment(ec2_connection, spot_requests)
Expand Down Expand Up @@ -272,7 +323,29 @@ def _check_instances():
print('No bees have been mobilized.')
return

ec2_connection = boto.ec2.connect_to_region(_get_region(zone))
file = open(CONFIG_FILENAME, "r")
lines = file.readlines()
remembered_mfa_serial=lines[0].replace("\n","")
remembered_region=lines[1]

mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial
mfa_TOTP = raw_input("Enter the MFA code: ")

sts_connection = STSConnection()

tempCredentials = sts_connection.get_session_token(
duration=3600,
mfa_serial_number=mfa_serial,
mfa_token=mfa_TOTP
)

region = boto.ec2.get_region(remembered_region)
ec2_connection = EC2Connection(
region=region,
aws_access_key_id=tempCredentials.access_key,
aws_secret_access_key=tempCredentials.secret_key,
security_token=tempCredentials.session_token
)

reservations = ec2_connection.get_all_instances(instance_ids=instance_ids)

Expand All @@ -292,17 +365,41 @@ def down(*mr_zone):
"""
Shutdown the load testing server.
"""
def _check_to_down_it():

def _check_to_down_it(region):
'''check if we can bring down some bees'''
username, key_name, zone, instance_ids = _read_server_list(region)

username, key_name, zone, instance_ids = _read_server_list(region)

if not instance_ids:
print('No bees have been mobilized.')
return

print('Connecting to the hive.')

ec2_connection = boto.ec2.connect_to_region(_get_region(zone))
file = open(CONFIG_FILENAME, "r")
lines = file.readlines()
remembered_mfa_serial=lines[0].replace("\n","")
remembered_region=lines[1]

mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial
mfa_TOTP = raw_input("Enter the MFA code: ")

sts_connection = STSConnection()

tempCredentials = sts_connection.get_session_token(
duration=3600,
mfa_serial_number=mfa_serial,
mfa_token=mfa_TOTP
)

region = boto.ec2.get_region(remembered_region)
ec2_connection = EC2Connection(
region=region,
aws_access_key_id=tempCredentials.access_key,
aws_secret_access_key=tempCredentials.secret_key,
security_token=tempCredentials.session_token
)

print(('Calling off the swarm for {}.').format(region))

Expand All @@ -313,12 +410,11 @@ def _check_to_down_it():

_delete_server_list(zone)


if len(mr_zone) > 0:
username, key_name, zone, instance_ids = _read_server_list(mr_zone[-1])
else:
for region in _get_existing_regions():
_check_to_down_it()
_check_to_down_it(region)

def _wait_for_spot_request_fulfillment(conn, requests, fulfilled_requests = []):
"""
Expand Down Expand Up @@ -842,7 +938,29 @@ def hurl_attack(url, n, c, **options):

print('Connecting to the hive.')

ec2_connection = boto.ec2.connect_to_region(_get_region(zone))
file = open(CONFIG_FILENAME, "r")
lines = file.readlines()
remembered_mfa_serial=lines[0].replace("\n","")
remembered_region=lines[1]

mfa_serial = raw_input("Enter the MFA serial [" + remembered_mfa_serial + "]: ") or remembered_mfa_serial
mfa_TOTP = raw_input("Enter the MFA code: ")

sts_connection = STSConnection()

tempCredentials = sts_connection.get_session_token(
duration=3600,
mfa_serial_number=mfa_serial,
mfa_token=mfa_TOTP
)

region = boto.ec2.get_region(remembered_region)
ec2_connection = EC2Connection(
region=region,
aws_access_key_id=tempCredentials.access_key,
aws_secret_access_key=tempCredentials.secret_key,
security_token=tempCredentials.session_token
)

print('Assembling bees.')

Expand Down Expand Up @@ -1307,3 +1425,5 @@ def _get_existing_regions():
something= re.search(r'\.bees\.(.*)', f)
existing_regions.append( something.group(1)) if something else "no"
return existing_regions