From e6aedf922e1fbac1bef4c868a4dd6e252c0a5f33 Mon Sep 17 00:00:00 2001 From: thevickypedia Date: Mon, 29 Jan 2024 21:42:14 -0600 Subject: [PATCH] Restructure code to remove redundancy and reduce overhead Add alias record right when the instance is ready This will allow time for DNS propagation before validation --- docs/index.html | 2 +- vpn/main.py | 174 +++++++++++++++++++++--------------------- vpn/models/config.py | 18 +++-- vpn/models/route53.py | 6 +- vpn/models/server.py | 14 ++-- 5 files changed, 107 insertions(+), 107 deletions(-) diff --git a/docs/index.html b/docs/index.html index 029b1d0..6ce7992 100644 --- a/docs/index.html +++ b/docs/index.html @@ -665,7 +665,7 @@

Configuration

SSH Configuration

-class vpn.models.server.Server(hostname: str, username: str, logger: Logger, env: EnvConfig, settings: Settings)
+class vpn.models.server.Server(hostname: str, username: str, logger: Logger)

Initiates Server object to create an SSH session to configure the server.

>>> Server
 
diff --git a/vpn/main.py b/vpn/main.py index 8efd406..775df47 100644 --- a/vpn/main.py +++ b/vpn/main.py @@ -16,12 +16,8 @@ from typing_extensions import Unpack from urllib3.exceptions import InsecureRequestWarning -from vpn.models.config import EnvConfig, Settings, configuration_dict -from vpn.models.exceptions import NotImplementedWarning -from vpn.models.image_factory import ImageFactory -from vpn.models.logger import LOGGER -from vpn.models.route53 import change_record_set, get_zone_id -from vpn.models.server import Server +from vpn.models import (config, exceptions, image_factory, logger, route53, + server) class VPNServer: @@ -31,7 +27,7 @@ class VPNServer: """ - def __init__(self, **kwargs: Unpack[Union[EnvConfig, Logger]]): + def __init__(self, **kwargs: Unpack[Union[config.EnvConfig, Logger]]): """Unpacks the kwargs and loads it as Envconfig, the subclass for BaseSettings validated using pydantic. Args: @@ -70,21 +66,21 @@ def __init__(self, **kwargs: Unpack[Union[EnvConfig, Logger]]): - | The alias record will fail as duplicate at DNS name resolution level unless a domain is | registered with the same name. """ - self.env = EnvConfig(**kwargs) - self.settings = Settings() - self.settings.key_pair_file = f"{self.env.key_pair}.pem" - if any((self.env.hosted_zone, self.env.subdomain)): - assert all((self.env.hosted_zone, self.env.subdomain)), "'subdomain' and 'hosted_zone' must co-exist" - self.settings.entrypoint = f'{self.env.subdomain}.{self.env.hosted_zone}' - self.settings.openvpn_config_commands = configuration_dict(self.env) - - self.logger = kwargs.get('logger') or LOGGER - self.session = boto3.Session(region_name=self.env.aws_region_name, - profile_name=self.env.aws_profile_name, - aws_access_key_id=self.env.aws_access_key, - aws_secret_access_key=self.env.aws_secret_key) + # Load any kwargs into EnvConfig + config.env = config.EnvConfig(**kwargs) + if any((config.env.hosted_zone, config.env.subdomain)): + assert all((config.env.hosted_zone, config.env.subdomain)), "'subdomain' and 'hosted_zone' must co-exist" + config.settings.entrypoint = f'{config.env.subdomain}.{config.env.hosted_zone}' + config.settings.key_pair_file = f"{config.env.key_pair}.pem" + config.settings.openvpn_config_commands = config.configuration_dict(config.env) + + self.logger = kwargs.get('logger') or logger.LOGGER + self.session = boto3.Session(region_name=config.env.aws_region_name, + profile_name=config.env.aws_profile_name, + aws_access_key_id=config.env.aws_access_key, + aws_secret_access_key=config.env.aws_secret_key) self.logger.info("Session instantiated for region: '%s' with '%s' instance", - self.session.region_name, self.env.instance_type) + self.session.region_name, config.env.instance_type) self.ec2_resource = self.session.resource(service_name='ec2') self.route53_client = self.session.client(service_name='route53') @@ -102,20 +98,20 @@ def _init(self, """ if start: # Not required during shutdown, since image_id is only used to create an ec2 instance variable = "created in" # var for logging if entrypoint is present - if self.env.image_id: - self.image_id = self.env.image_id + if config.env.image_id: + self.image_id = config.env.image_id else: - self.image_id = ImageFactory(self.session, self.logger).get_image_id() + self.image_id = image_factory.ImageFactory(self.session, self.logger).get_image_id() else: variable = "removed from" # var for logging if entrypoint is present - if self.env.hosted_zone: - self.zone_id = get_zone_id(client=self.route53_client, - logger=self.logger, - dns=self.env.hosted_zone, - init=True) - if self.settings.entrypoint: + if config.settings.entrypoint: + if not self.zone_id: + self.zone_id = route53.get_zone_id(client=self.route53_client, + logger=self.logger, + dns=config.env.hosted_zone, + init=start) self.logger.info("Entrypoint: '%s' will be %s the hosted zone [%s] '%s'", - self.settings.entrypoint, variable, self.zone_id, self.env.hosted_zone) + config.settings.entrypoint, variable, self.zone_id, config.env.hosted_zone) def _create_key_pair(self) -> bool: """Creates a ``KeyPair`` of type ``RSA`` stores as a ``PEM`` file for SSH connection. @@ -126,24 +122,24 @@ def _create_key_pair(self) -> bool: """ try: key_pair = self.ec2_resource.create_key_pair( - KeyName=self.env.key_pair, + KeyName=config.env.key_pair, KeyType='rsa' ) except ClientError as error: error = str(error) if '(InvalidKeyPair.Duplicate)' in error: self.logger.warning('Found an existing KeyPair named: %s. Re-creating it.', - self.env.key_pair) + config.env.key_pair) self._delete_key_pair() return self._create_key_pair() self.logger.warning('API call to create key pair has failed.') self.logger.error(error) return False - with open(self.settings.key_pair_file, 'w') as file: + with open(config.settings.key_pair_file, 'w') as file: file.write(key_pair.key_material) file.flush() - self.logger.info('Stored KeyPair as %s', self.settings.key_pair_file) + self.logger.info('Stored KeyPair as %s', config.settings.key_pair_file) return True def _get_vpc_id(self) -> Union[str, None]: @@ -203,8 +199,8 @@ def _authorize_security_group(self, 'ToPort': 443, 'IpRanges': [{'CidrIp': '0.0.0.0/0'}]}, {'IpProtocol': 'tcp', - 'FromPort': self.env.vpn_port, - 'ToPort': self.env.vpn_port, + 'FromPort': config.env.vpn_port, + 'ToPort': config.env.vpn_port, 'IpRanges': [{'CidrIp': '0.0.0.0/0'}]}, {'IpProtocol': 'tcp', 'FromPort': 945, @@ -249,16 +245,16 @@ def _create_security_group(self) -> Union[str, None]: try: security_group = self.ec2_resource.create_security_group( - GroupName=self.env.security_group, + GroupName=config.env.security_group, Description='Security Group to allow certain port ranges for exposing localhost to public internet.', VpcId=vpc_id ) except ClientError as error: error = str(error) - if '(InvalidGroup.Duplicate)' in error and self.env.security_group in error: + if '(InvalidGroup.Duplicate)' in error and config.env.security_group in error: security_groups = list(self.ec2_resource.security_groups.all()) for security_group in security_groups: - if security_group.group_name == self.env.security_group: + if security_group.group_name == config.env.security_group: self.logger.info("Re-using existing SecurityGroup '%s'", security_group.group_id) return security_group.group_id raise RuntimeError('Duplicate raised, but no such SG found.') @@ -288,8 +284,8 @@ def _create_ec2_instance(self) -> Union[Tuple[str, str], None]: ImageId=self.image_id, MinCount=1, MaxCount=1, - InstanceType=self.env.instance_type, - KeyName=self.env.key_pair, + InstanceType=config.env.instance_type, + KeyName=config.env.key_pair, SecurityGroupIds=[security_group_id] ) instance = instances[0] # Get the first (and only) instance @@ -312,20 +308,20 @@ def _delete_key_pair(self) -> bool: Flag to indicate the calling function, if the KeyPair was deleted successfully. """ try: - key_pair = self.ec2_resource.KeyPair(self.env.key_pair) + key_pair = self.ec2_resource.KeyPair(config.env.key_pair) key_pair.delete() except ClientError as error: - self.logger.warning("API call to delete the key '%s' has failed.", self.env.key_pair) + self.logger.warning("API call to delete the key '%s' has failed.", config.env.key_pair) self.logger.error(error) return False - self.logger.info('%s has been deleted from KeyPairs.', self.env.key_pair) + self.logger.info('%s has been deleted from KeyPairs.', config.env.key_pair) # Delete the associated .pem file if it exists - if os.path.exists(self.settings.key_pair_file): - os.chmod(self.settings.key_pair_file, int('700', base=8) or 0o700) - os.remove(self.settings.key_pair_file) - self.logger.info(f'Removed {self.settings.key_pair_file}.') + if os.path.exists(config.settings.key_pair_file): + os.chmod(config.settings.key_pair_file, int('700', base=8) or 0o700) + os.remove(config.settings.key_pair_file) + self.logger.info(f'Removed {config.settings.key_pair_file}.') return True def _disassociate_security_group(self, @@ -423,7 +419,7 @@ def _test_get(self, host: str, timeout: Tuple = (3, 3), retries: int = 5) -> Res """ for i in range(1, retries + 1): try: - response = requests.get(url=f"https://{host}:{self.env.vpn_port}", + response = requests.get(url=f"https://{host}:{config.env.vpn_port}", verify=False, timeout=timeout) self.logger.debug(response) return response @@ -457,9 +453,9 @@ def _tester(self, data: Dict[str, Union[str, int]]) -> None: """ urllib3.disable_warnings(InsecureRequestWarning) # Disable warnings for self-signed certificates alias_thread = None - if self.settings.entrypoint: + if config.settings.entrypoint: alias_thread = ThreadPool(processes=1).apply_async(self._test_get, - args=(self.settings.entrypoint,)) + args=(config.settings.entrypoint,)) self.logger.info("Testing GET connections to VPN server, via hostname and IP address.") ip_thread = ThreadPool(processes=1).apply_async(self._test_get, kwds=dict(host=data.get('public_ip'), retries=2)) @@ -471,24 +467,23 @@ def _tester(self, data: Dict[str, Union[str, int]]) -> None: "One or more tests for GET connection has failed. Please check the logs for more information." self.logger.info("Connections to VPN server, via hostname and IP address were successful.") self.logger.info("Testing SSH connection to %s", data.get('public_dns')) - test_ssh = Server(username=self.env.vpn_username, hostname=data.get('public_dns'), logger=self.logger, - env=self.env, settings=self.settings) + test_ssh = server.Server(username=config.env.vpn_username, hostname=data.get('public_dns'), logger=self.logger) test_ssh.test_service(display=False, timeout=5) self.logger.info(f"SSH to {data.get('public_dns')} was successful.") if alias_thread: if (alias_response := alias_thread.get()) and alias_response.ok: self.logger.info("Connection to VPN server, via alias record %s was successful.", - self.settings.entrypoint) + config.settings.entrypoint) else: self.logger.error("Failed to test A record, it may be DNS propagation delay. ") def test_vpn(self) -> None: """Tests the ``GET`` and ``SSH`` connections to an existing VPN server.""" try: - with open(self.env.vpn_info) as file: + with open(config.env.vpn_info) as file: data_exist = json.load(file) except FileNotFoundError: - self.logger.error(f'Input file: {self.env.vpn_info} is missing. CANNOT proceed.') + self.logger.error(f'Input file: {config.env.vpn_info} is missing. CANNOT proceed.') return self._tester(data=data_exist) @@ -499,12 +494,12 @@ def create_vpn_server(self) -> Union[Dict[str, Union[str, int]], None]: Dict[str, Union[str, int]] or None: VPN access server information. """ - if os.path.isfile(self.env.vpn_info) and os.path.isfile(self.settings.key_pair_file): + if os.path.isfile(config.env.vpn_info) and os.path.isfile(config.settings.key_pair_file): self.logger.warning('Received request to start VM, but looks like a session is up and running already.') self.logger.warning('Initiating re-configuration.') - with open(self.env.vpn_info) as file: + with open(config.env.vpn_info) as file: data = json.load(file) - self.env.image_id = 'ami-0000000000' # placeholder value since this won't be used in re-configuration + config.env.image_id = 'ami-0000000000' # placeholder value since this won't be used in re-configuration self._init(True) try: self._tester(data) @@ -528,7 +523,7 @@ def create_vpn_server(self) -> Union[Dict[str, Union[str, int]], None]: warnings.warn( "Failed on waiting for instance to enter 'running' state, please raise an issue at:\n" "https://github.com/thevickypedia/vpn-server/issues", - NotImplementedWarning + exceptions.NotImplementedWarning ) self._delete_key_pair() # No need to wait for SG disassociation since this is a handler for a WaiterError already @@ -553,40 +548,41 @@ def create_vpn_server(self) -> Union[Dict[str, Union[str, int]], None]: warnings.warn( "Failed on waiting for instance to enter 'running' state, please raise an issue at:\n" "https://github.com/thevickypedia/vpn-server/issues", - NotImplementedWarning + exceptions.NotImplementedWarning ) self._delete_security_group(security_group_id) return instance_info = { - 'port': self.env.vpn_port, + 'port': config.env.vpn_port, 'instance_id': instance_id, 'public_dns': instance.public_dns_name, 'public_ip': instance.public_ip_address, 'security_group_id': security_group_id, - 'ssh_endpoint': f'ssh -i {self.settings.key_pair_file} openvpnas@{instance.public_dns_name}' + 'ssh_endpoint': f'ssh -i {config.settings.key_pair_file} openvpnas@{instance.public_dns_name}' } - os.chmod(self.settings.key_pair_file, int('400', base=8) or 0o400) + if config.settings.entrypoint: + if route53.change_record_set(source=config.settings.entrypoint, + destination=instance.public_ip_address, + logger=self.logger, + client=self.route53_client, + zone_id=self.zone_id, action='UPSERT'): + instance_info['entrypoint'] = config.settings.entrypoint + else: + self.logger.error("Failed to add entrypoint as alias") + config.settings.entrypoint = None - with open(self.env.vpn_info, 'w') as file: + os.chmod(config.settings.key_pair_file, int('400', base=8) or 0o400) + + with open(config.env.vpn_info, 'w') as file: json.dump(instance_info, file, indent=2) file.flush() self._configure_vpn(instance.public_dns_name) - if self.settings.entrypoint: - if change_record_set(source=self.settings.entrypoint, - destination=instance.public_ip_address, - logger=self.logger, - client=self.route53_client, - zone_id=self.zone_id, action='UPSERT'): - instance_info['entrypoint'] = self.settings.entrypoint - with open(self.env.vpn_info, 'w') as file: - json.dump(instance_info, file, indent=2) - file.flush() - else: - self.logger.error("Failed to add entrypoint as alias") - self.settings.entrypoint = None + with open(config.env.vpn_info, 'w') as file: + json.dump(instance_info, file, indent=2) + file.flush() try: self._tester(data=instance_info) @@ -594,7 +590,7 @@ def create_vpn_server(self) -> Union[Dict[str, Union[str, int]], None]: self.logger.error('Failed to configure VPN server. Please check the logs for more information.') else: self.logger.info('VPN server has been configured successfully. Details have been stored in %s.', - self.env.vpn_info) + config.env.vpn_info) return instance_info def _configure_vpn(self, public_dns: str) -> None: @@ -608,8 +604,7 @@ def _configure_vpn(self, public_dns: str) -> None: # Max of 10 iterations with 5 second interval between each iteration with default timeout for i in range(10): try: - server = Server(hostname=public_dns, username='openvpnas', logger=self.logger, - env=self.env, settings=self.settings) + vpn_server = server.Server(hostname=public_dns, username='openvpnas', logger=self.logger) self.logger.info("Connection established on %s attempt", self.engine.ordinal(i + 1)) break except Exception as error: @@ -620,7 +615,7 @@ def _configure_vpn(self, public_dns: str) -> None: raise TimeoutError( "Unable to connect SSH server, please call the 'start' function once again if instance looks healthy" ) - server.run_interactive_ssh() + vpn_server.run_interactive_ssh() def delete_vpn_server(self, instance_id: str = None, @@ -641,11 +636,11 @@ def delete_vpn_server(self, | wait_until_terminated.html """ try: - with open(self.env.vpn_info) as file: + with open(config.env.vpn_info) as file: data = json.load(file) except FileNotFoundError: assert instance_id and security_group_id, \ - (f"\n\nInput file: {self.env.vpn_info!r} is missing. " + (f"\n\nInput file: {config.env.vpn_info!r} is missing. " "Arguments 'instance_id' and 'security_group_id' are required to proceed.") data = {} self._init(False) @@ -656,9 +651,10 @@ def delete_vpn_server(self, self._delete_key_pair() sg_association = self._disassociate_security_group(instance_id=instance_id, security_group_id=security_group_id) instance = self._terminate_ec2_instance(instance_id=instance_id) - if self.env.hosted_zone and self.env.subdomain and public_ip: - change_record_set(source=self.settings.entrypoint, destination=public_ip, - logger=self.logger, client=self.route53_client, zone_id=self.zone_id, action='DELETE') + if config.settings.entrypoint and public_ip: + route53.change_record_set(source=config.settings.entrypoint, destination=public_ip, + logger=self.logger, client=self.route53_client, zone_id=self.zone_id, + action='DELETE') if not sg_association and instance: try: instance.wait_until_terminated( @@ -667,4 +663,4 @@ def delete_vpn_server(self, except WaiterError as error: self.logger.error(error) self._delete_security_group(security_group_id) - os.remove(self.env.vpn_info) if os.path.isfile(self.env.vpn_info) else None + os.remove(config.env.vpn_info) if os.path.isfile(config.env.vpn_info) else None diff --git a/vpn/models/config.py b/vpn/models/config.py index 97a76d5..fc606d8 100644 --- a/vpn/models/config.py +++ b/vpn/models/config.py @@ -93,6 +93,9 @@ def validate_instance_type(cls, v: str) -> str: return v +env = EnvConfig + + class Settings(BaseModel): """Wrapper for configuration settings. @@ -105,7 +108,10 @@ class Settings(BaseModel): openvpn_config_commands: List[ConfigurationSettings] = [] -def configuration_dict(env: EnvConfig) -> List[ConfigurationSettings]: +settings = Settings() + + +def configuration_dict(param: EnvConfig) -> List[ConfigurationSettings]: """Get configuration interaction as a list of dictionaries.""" for config_dict in [ {'request': "Please enter 'yes' to indicate your agreement \\[no\\]: ", 'response': 'yes', 'timeout': 5, @@ -118,7 +124,7 @@ def configuration_dict(env: EnvConfig) -> List[ConfigurationSettings]: {'request': '> Press ENTER for default \\[rsa\\]:', 'response': 'rsa', 'timeout': 1, 'critical': False}, {'request': '> Press ENTER for default \\[ 2048 \\]:', 'response': '2048', 'timeout': 1, 'critical': False}, - {'request': '> Press ENTER for default \\[943\\]: ', 'response': env.vpn_port, 'timeout': 1, + {'request': '> Press ENTER for default \\[943\\]: ', 'response': param.vpn_port, 'timeout': 1, 'critical': False}, {'request': '> Press ENTER for default \\[443\\]: ', 'response': '443', 'timeout': 1, 'critical': False}, {'request': '> Press ENTER for default \\[no\\]: ', 'response': 'yes', 'timeout': 1, 'critical': False}, @@ -127,11 +133,11 @@ def configuration_dict(env: EnvConfig) -> List[ConfigurationSettings]: 'critical': False}, {'request': '> Press ENTER for default \\[yes\\]: ', 'response': 'no', 'timeout': 1, 'critical': False}, {'request': '> Specify the username for an existing user or for the new user account: ', - 'response': env.vpn_username, 'timeout': 1, 'critical': True}, - {'request': f"Type a password for the '{env.vpn_username}' account " + 'response': param.vpn_username, 'timeout': 1, 'critical': True}, + {'request': f"Type a password for the '{param.vpn_username}' account " "(if left blank, a random password will be generated):", - 'response': env.vpn_password, 'timeout': 1, 'critical': True}, - {'request': f"Confirm the password for the '{env.vpn_username}' account:", 'response': env.vpn_password, + 'response': param.vpn_password, 'timeout': 1, 'critical': True}, + {'request': f"Confirm the password for the '{param.vpn_username}' account:", 'response': param.vpn_password, 'timeout': 1, 'critical': True}, {'request': '> Please specify your Activation key (or leave blank to specify later): ', 'response': '\n', 'timeout': 1, 'critical': False} diff --git a/vpn/models/route53.py b/vpn/models/route53.py index af813e0..a96cc97 100644 --- a/vpn/models/route53.py +++ b/vpn/models/route53.py @@ -5,7 +5,7 @@ import boto3 from botocore.exceptions import ClientError -from vpn.models.exceptions import AWSResourceError +from vpn.models import exceptions def get_zone_id(client: boto3.client, @@ -34,7 +34,7 @@ def get_zone_id(client: boto3.client, logger.error(response) if init: status_code = response.get('ResponseMetadata', {}).get('HTTPStatusCode', 500) - raise AWSResourceError(status_code, http_response[status_code]) + raise exceptions.AWSResourceError(status_code, http_response[status_code]) return if hosted_zones := response.get('HostedZones'): @@ -42,7 +42,7 @@ def get_zone_id(client: boto3.client, if hosted_zone['Name'] in (dns, f'{dns}.'): return hosted_zone['Id'].split('/')[-1] if init: - raise AWSResourceError(404, f'No HostedZones found for the DNSName: {dns}') + raise exceptions.AWSResourceError(404, f'No HostedZones found for the DNSName: {dns}') logger.error(f'No HostedZones found for the DNSName: {dns}\n{response}') diff --git a/vpn/models/server.py b/vpn/models/server.py index 277493f..9173530 100644 --- a/vpn/models/server.py +++ b/vpn/models/server.py @@ -7,7 +7,7 @@ from paramiko.ssh_exception import AuthenticationException from paramiko_expect import SSHClientInteraction -from vpn.models.config import EnvConfig, Settings +from vpn.models import config class Server: @@ -17,24 +17,22 @@ class Server: """ - def __init__(self, hostname: str, username: str, logger: logging.Logger, env: EnvConfig, settings: Settings): + def __init__(self, hostname: str, username: str, logger: logging.Logger): """Instantiates the session using RSAKey generated from a ``***.pem`` file. Args: hostname: Hostname of the server. """ self.logger = logger - self.env = env - self.settings = settings - pem_key = RSAKey.from_private_key_file(filename=settings.key_pair_file) + pem_key = RSAKey.from_private_key_file(filename=config.settings.key_pair_file) self.ssh_client = SSHClient() self.ssh_client.load_system_host_keys() self.ssh_client.set_missing_host_key_policy(policy=AutoAddPolicy()) - if username == self.env.vpn_username: + if username == config.env.vpn_username: try: # todo: Manual config accepts username and password, but unable to get authentication pass via paramiko self.ssh_client.connect(hostname=hostname, username=username, - pkey=pem_key, password=self.env.vpn_password) + pkey=pem_key, password=config.env.vpn_password) except AuthenticationException as error: self.logger.warning(error) self.ssh_client.connect(hostname=hostname, username='openvpnas', pkey=pem_key) @@ -102,7 +100,7 @@ def run_interactive_ssh(self, timeout=timeout, display=display, output_callback=lambda msg: self.logger.info(msg)) as interact: - for setting in self.settings.openvpn_config_commands: + for setting in config.settings.openvpn_config_commands: interact.expect(re_strings=setting.request, timeout=setting.timeout) interact.send(send_string=str(setting.response)) # Blank to await final steps of configuration