diff --git a/src/vmcontrol/sessionhandler/sessionhandler.py b/src/vmcontrol/sessionhandler/sessionhandler.py index 7fbcadb..2bc450a 100644 --- a/src/vmcontrol/sessionhandler/sessionhandler.py +++ b/src/vmcontrol/sessionhandler/sessionhandler.py @@ -40,7 +40,8 @@ class SessionConfig(DictNamespace): "Company Router", "Log Server", "Internal Server", - "DMZ Server"] + "DMZ Server", + ] client_vm = "Client" number_of_clones = 3 vm_start_timeout = 0 @@ -54,14 +55,18 @@ class Clone(DictNamespace): user = None password = None domain = None + vrde_port = None class CloneCreator: - def __init__(self, father_vm, base_snapshot, number_of_clones, vmm_controller: VMMController): - self.father_vm = father_vm + def __init__( + self, parent_vm, base_snapshot, number_of_clones, vmm_controller: VMMController, vrde_port_start=5000 + ): + self.parent_vm = parent_vm self.base_snapshot = base_snapshot self.number_of_clones = number_of_clones self.vmmc = vmm_controller + self.vrde_port_start = vrde_port_start self._current_vms = self.vmmc.get_vms() self._clones = None @@ -77,9 +82,10 @@ def set_data(self, clone): self.set_management_mac(clone) self.set_internal_mac(clone) self.set_credentials(clone) + self.set_vrde_port(clone) def set_vm_name(self, clone): - clone.vm = self.father_vm + "Clone" + str(clone.id) + clone.vm = self.parent_vm + "Clone" + str(clone.id) while clone.vm in self._current_vms: clone.vm += "a" @@ -94,10 +100,14 @@ def set_credentials(self, clone): clone.password = "breach" clone.domain = "BREACH" + def set_vrde_port(self, clone): + clone.vrde_port = self.vrde_port_start + clone.id - 1 + def create_vm(self, clone): - self.vmmc.clone(self.father_vm, self.base_snapshot, clone.vm) + self.vmmc.clone(self.parent_vm, self.base_snapshot, clone.vm) self.vmmc.set_mac(clone.vm, clone.management_mac, if_id=2) self.vmmc.set_mac(clone.vm, clone.internal_mac, if_id=1) + self.vmmc.set_vrde_port(clone.vm, clone.vrde_port) class SessionHandler: @@ -192,10 +202,9 @@ def create_backup_snapshots(self, vms): def create_clones(self): logger.info("Creating clones") - father_vm = self.config.client_vm - base_snapshot = self.backup_snapshots[father_vm] - clone_creator = self.clone_creator_class( - father_vm, base_snapshot, self.config.number_of_clones, self.vmmc) + parent_vm = self.config.client_vm + base_snapshot = self.backup_snapshots[parent_vm] + clone_creator = self.clone_creator_class(parent_vm, base_snapshot, self.config.number_of_clones, self.vmmc) self.clones = clone_creator.create() def start_all_vms(self): @@ -256,8 +265,8 @@ def restore_delete_snapshots(self, snapshots): time.sleep(timeout) if snaps: raise SessionHandlerException( - "Could not restore and delete all snapshots. Failing: {snaps}" - .format(snaps=snaps)) + "Could not restore and delete all snapshots. Failing: {snaps}".format(snaps=snaps) + ) def poweroff_vms(self, vms): fails, max_fails, timeout = 0, 100, 0.01 @@ -272,7 +281,8 @@ def poweroff_vms(self, vms): time.sleep(timeout) if running_vms: raise SessionHandlerException( - "Could not poweroff all machines. Still alive: {vms}".format(vms=running_vms)) + "Could not poweroff all machines. Still alive: {vms}".format(vms=running_vms) + ) def delete_vms(self, vms): fails, max_fails, timeout = 0, 100, 0.01 @@ -287,8 +297,7 @@ def delete_vms(self, vms): existing_vms.append(vm) time.sleep(timeout) if existing_vms: - raise SessionHandlerException( - "Could not delete all machines. Still there: {vms}".format(vms=existing_vms)) + raise SessionHandlerException("Could not delete all machines. Still there: {vms}".format(vms=existing_vms)) def take_vm_start_timeout(self): time.sleep(self.config.vm_start_timeout) diff --git a/src/vmcontrol/sessionhandler/tests/test_sessionhandler.py b/src/vmcontrol/sessionhandler/tests/test_sessionhandler.py index 50d863c..cd85b09 100644 --- a/src/vmcontrol/sessionhandler/tests/test_sessionhandler.py +++ b/src/vmcontrol/sessionhandler/tests/test_sessionhandler.py @@ -21,8 +21,7 @@ import pytest -from vmcontrol.sessionhandler import SessionHandler, SessionConfig, SessionHandlerException, \ - CloneCreator +from vmcontrol.sessionhandler import SessionHandler, SessionConfig, SessionHandlerException, CloneCreator from vmcontrol.sessionhandler.sessionhandler import Clone, DictNamespace from vmcontrol.vmmcontroller import VMMControllerException from vmcontrol.vmmcontroller.tests.mocks import MockVMMController @@ -109,8 +108,13 @@ class TestSessionHandler: def test_default_config(self): config = SessionHandler.default_config() assert config.server_vms == [ - "Internet Router", "Attacker", "Company Router", "Log Server", "Internal Server", - "DMZ Server"] + "Internet Router", + "Attacker", + "Company Router", + "Log Server", + "Internal Server", + "DMZ Server", + ] assert config.client_vm == "Client" assert config.number_of_clones == 3 assert config.vm_start_timeout == 0 @@ -209,8 +213,7 @@ def test_login_clones(self, sh: SessionHandler): sh.create_clones() sh.login_clones() for clone in sh.clones: - sh.vmmc.set_credentials.assert_any_call( - clone.vm, clone.user, clone.password, clone.domain) + sh.vmmc.set_credentials.assert_any_call(clone.vm, clone.user, clone.password, clone.domain) def test_start_session(self, sh: SessionHandler): server_vms = sh.config.server_vms @@ -309,16 +312,16 @@ def test_init_default_fields(self): @pytest.fixture() def cc(): - father_vm = "Client" + parent_vm = "Client" base_snapshot = "CloneShot" number_of_clones = 5 mvmmc = MockVMMController() some_vm = mvmmc.get_vms()[0] some_shot = "Snap" mvmmc.create_snapshot(some_vm, some_shot) - mvmmc.clone(some_vm, some_shot, father_vm) - mvmmc.create_snapshot(father_vm, base_snapshot) - cc = CloneCreator(father_vm, base_snapshot, number_of_clones, mvmmc) + mvmmc.clone(some_vm, some_shot, parent_vm) + mvmmc.create_snapshot(parent_vm, base_snapshot) + cc = CloneCreator(parent_vm, base_snapshot, number_of_clones, mvmmc) return cc @@ -334,7 +337,7 @@ def test_generate_ids(self, cc: CloneCreator): def test_set_vm_name(self, cc: CloneCreator): clone = Clone(id=42) cc.set_vm_name(clone) - assert clone.vm.startswith(cc.father_vm) + assert clone.vm.startswith(cc.parent_vm) def test_set_vm_name_avoiding_current_vms(self, cc: CloneCreator): clone = Clone(id=42) diff --git a/src/vmcontrol/vmmcontroller/tests/mocks.py b/src/vmcontrol/vmmcontroller/tests/mocks.py index 1c9a732..c56a547 100644 --- a/src/vmcontrol/vmmcontroller/tests/mocks.py +++ b/src/vmcontrol/vmmcontroller/tests/mocks.py @@ -25,8 +25,11 @@ def printing_method(*args, **kwargs): self = args[0] if self.printing: method_name = str(method.__name__)[6:] - print("called \"{method_name}\" with {args} and {kwargs}" - .format(method_name=method_name, args=args[1:], kwargs=kwargs)) + print( + 'called "{method_name}" with {args} and {kwargs}'.format( + method_name=method_name, args=args[1:], kwargs=kwargs + ) + ) return method(*args, **kwargs) return printing_method @@ -40,10 +43,7 @@ class VM: def __init__(self, name): self.name = name or "" self.snapshots = list() - self.macs = [ - int("0080123456AA", 16) + self.mac_count, - int("00801BB456AA", 16) + self.mac_count - ] + self.macs = [int("0080123456AA", 16) + self.mac_count, int("00801BB456AA", 16) + self.mac_count] VM.mac_count += 1 @@ -139,3 +139,7 @@ def set_mac(self, vm, mac, if_id=1): @mock_print_decorator() def set_credentials(self, vm, user, password, domain): pass + + @mock_print_decorator() + def set_vrde_port(self, vm, port): + pass diff --git a/src/vmcontrol/vmmcontroller/tests/test_vboxcontroller.py b/src/vmcontrol/vmmcontroller/tests/test_vboxcontroller.py index e03cec9..90059fe 100644 --- a/src/vmcontrol/vmmcontroller/tests/test_vboxcontroller.py +++ b/src/vmcontrol/vmmcontroller/tests/test_vboxcontroller.py @@ -39,11 +39,13 @@ def test_vboxmanage_mock(self, vbc: VBoxController): assert isinstance(vbc._vboxmanage_execute, Mock) def test_get_vms(self, vbc: VBoxController): - return_value = "\n".join([ - "\"Attacker\" {8451900b-320a-43b4-9eb9-9bd6656f33ad}", - "\"Client\" {4d0986c7-eabc-4cd3-a2f3-e28111a66ac1}", - "\"Company Router\" {d75d3108-266b-420e-aa74-bc73f689ee9c}" - ]) + return_value = "\n".join( + [ + '"Attacker" {8451900b-320a-43b4-9eb9-9bd6656f33ad}', + '"Client" {4d0986c7-eabc-4cd3-a2f3-e28111a66ac1}', + '"Company Router" {d75d3108-266b-420e-aa74-bc73f689ee9c}', + ] + ) vbc._vboxmanage_execute = Mock(return_value=return_value) vms = vbc.get_vms() assert_call(vbc, ["list", "vms"]) @@ -67,11 +69,13 @@ def test_is_running(self, vbc: VBoxController): assert not vbc.is_running("VM2") def test_get_running_vms(self, vbc: VBoxController): - return_value = "\n".join([ - "\"Attacker\" {8451900b-320a-43b4-9eb9-9bd6656f33ad}", - "\"Client\" {4d0986c7-eabc-4cd3-a2f3-e28111a66ac1}", - "\"Company Router\" {d75d3108-266b-420e-aa74-bc73f689ee9c}" - ]) + return_value = "\n".join( + [ + '"Attacker" {8451900b-320a-43b4-9eb9-9bd6656f33ad}', + '"Client" {4d0986c7-eabc-4cd3-a2f3-e28111a66ac1}', + '"Company Router" {d75d3108-266b-420e-aa74-bc73f689ee9c}', + ] + ) vbc._vboxmanage_execute = Mock(return_value=return_value) vms = vbc._get_running_vms() assert_call(vbc, ["list", "runningvms"]) @@ -89,7 +93,7 @@ def test_get_macs(self, vbc: VBoxController): return_value = { "macaddress1": "0800278144CB", "macaddress2": "080027859405", - "othervalue": "ABAB" + "othervalue": "ABAB", } vbc._get_vm_info = Mock(return_value=return_value) macs = vbc.get_macs("VM") @@ -100,7 +104,7 @@ def test_get_mac(self, vbc: VBoxController): return_value = { "macaddress1": "0800278144CB", "macaddress2": "080027859405", - "othervalue": "ABAB" + "othervalue": "ABAB", } vbc._get_vm_info = Mock(return_value=return_value) mac = vbc.get_mac("VM", if_id=1) @@ -111,7 +115,7 @@ def test_get_mac_with_if_id(self, vbc: VBoxController): return_value = { "macaddress1": "0800278144CB", "macaddress2": "080027859405", - "othervalue": "ABAB" + "othervalue": "ABAB", } vbc._get_vm_info = Mock(return_value=return_value) mac = vbc.get_mac("VM", if_id=2) @@ -122,7 +126,7 @@ def test_get_mac_exception(self, vbc: VBoxController): return_value = { "macaddress1": "0800278144CB", "macaddress2": "080027859405", - "othervalue": "ABAB" + "othervalue": "ABAB", } vbc._get_vm_info = Mock(return_value=return_value) with pytest.raises(VMMControllerException) as ei: @@ -156,6 +160,7 @@ def vbox_causes_exception(vbox_vector): raise VMMControllerException("Some text....") else: pass + vbc._vboxmanage_execute = Mock(side_effect=vbox_causes_exception) # get_vms may be used for hack vbc.get_vms = Mock(return_value=[vm_without_snapshot]) @@ -176,18 +181,13 @@ def test_restore_snapshot(self, vbc: VBoxController): def test_clone(self, vbc: VBoxController): vbc.clone("VM", "Snapshot", "VMClone") - assert_call(vbc, [ - "clonevm", "VM", "--name", "VMClone", - "--options", "link", "--snapshot", "Snapshot", - "--register" - ]) + assert_call( + vbc, ["clonevm", "VM", "--name", "VMClone", "--options", "link", "--snapshot", "Snapshot", "--register"] + ) def test_set_credentials(self, vbc: VBoxController): vbc.set_credentials("VM", "TheUser", "ThePassword", "TheDomain") - assert_call(vbc, [ - "controlvm", "VM", - "setcredentials", "TheUser", "ThePassword", "TheDomain" - ]) + assert_call(vbc, ["controlvm", "VM", "setcredentials", "TheUser", "ThePassword", "TheDomain"]) def _some_vbox_info_output(self): return """name="Company Router" @@ -339,7 +339,7 @@ def _some_snapshot_output(self): """ -["VBoxManage", "clonevm", self.father, "--name", c.vm_name, +["VBoxManage", "clonevm", self.parent, "--name", c.vm_name, "--options", "link", "--snapshot", "CloneSnapshot", "--register"]) -""" \ No newline at end of file +""" diff --git a/src/vmcontrol/vmmcontroller/vboxcontroller.py b/src/vmcontrol/vmmcontroller/vboxcontroller.py index 134c093..604d453 100644 --- a/src/vmcontrol/vmmcontroller/vboxcontroller.py +++ b/src/vmcontrol/vmmcontroller/vboxcontroller.py @@ -45,11 +45,7 @@ def is_running(self, vm): def get_macs(self, vm): vm_info = self._get_vm_info(vm) - macs = [ - int(value, 16) - for key, value in vm_info.items() - if key.startswith("macaddress") - ] + macs = [int(value, 16) for key, value in vm_info.items() if key.startswith("macaddress")] return macs def get_mac(self, vm, if_id=1): @@ -58,8 +54,7 @@ def get_mac(self, vm, if_id=1): mac_string = vm_info["macaddress" + str(if_id)] except KeyError as e: raise VMMControllerException( - "Cannot find MAC address for interface {if_id} on VM \"{vm}\"". - format(if_id=if_id, vm=vm) + 'Cannot find MAC address for interface {if_id} on VM "{vm}"'.format(if_id=if_id, vm=vm) ) mac = int(mac_string, 16) return mac @@ -86,10 +81,10 @@ def get_snapshots(self, vm): out_lines = out.splitlines() snapshot_lines = filter(lambda line: "=" in line, out_lines) snapshots = [ - value.strip("\"") + value.strip('"') for key, value in map(lambda line: line.split("=", maxsplit=1), snapshot_lines) - if key.strip("\"").startswith("SnapshotName") - ] + if key.strip('"').startswith("SnapshotName") + ] return snapshots def create_snapshot(self, vm, snapshot): @@ -105,18 +100,15 @@ def restore_snapshot(self, vm, snapshot): self._vboxmanage_execute(vbox_vector) def clone(self, vm, snapshot, clone): - vbox_vector = [ - "clonevm", vm, "--name", clone, - "--options", "link", "--snapshot", snapshot, - "--register" - ] + vbox_vector = ["clonevm", vm, "--name", clone, "--options", "link", "--snapshot", snapshot, "--register"] self._vboxmanage_execute(vbox_vector) def set_credentials(self, vm, user, password, domain): - vbox_vector = [ - "controlvm", vm, - "setcredentials", user, password, domain - ] + vbox_vector = ["controlvm", vm, "setcredentials", user, password, domain] + self._vboxmanage_execute(vbox_vector) + + def set_vrde_port(self, vm, port): + vbox_vector = ["modifyvm", vm, "--vrdeport", str(port)] self._vboxmanage_execute(vbox_vector) def _get_running_vms(self): @@ -127,7 +119,7 @@ def _get_running_vms(self): def _vm_string_to_list(self, vm_string): lines = vm_string.splitlines() - vms = [line.split("\"")[1] for line in lines] + vms = [line.split('"')[1] for line in lines] return vms def _get_vm_info(self, vm): @@ -136,10 +128,9 @@ def _get_vm_info(self, vm): out_lines = out.splitlines() info_lines = filter(lambda line: "=" in line, out_lines) vm_info = { - key.strip("\""): value.strip("\"") - for key, value in - map(lambda line: line.split("=", maxsplit=1), info_lines) - } + key.strip('"'): value.strip('"') + for key, value in map(lambda line: line.split("=", maxsplit=1), info_lines) + } return vm_info @staticmethod @@ -149,11 +140,7 @@ def _vboxmanage_execute(vbox_vector): out, err = p.communicate() if p.returncode != 0: raise VMMControllerException( - "Error in execution of {vector}\n" - "-------\n" - "{err}" - "-------" - .format(vector=vbox_vector, err=err) + "Error in execution of {vector}\n" "-------\n" "{err}" "-------".format(vector=vbox_vector, err=err) ) return out diff --git a/src/vmcontrol/vmmcontroller/vmmcontroller.py b/src/vmcontrol/vmmcontroller/vmmcontroller.py index a701120..f37523a 100644 --- a/src/vmcontrol/vmmcontroller/vmmcontroller.py +++ b/src/vmcontrol/vmmcontroller/vmmcontroller.py @@ -64,6 +64,9 @@ def clone(self, vm, snapshot, clone): def set_credentials(self, vm, user, password, domain): raise NotImplementedError() + def set_vrde_port(self, vm, port): + raise NotImplementedError() + class VMMControllerException(Exception): pass @@ -71,41 +74,42 @@ class VMMControllerException(Exception): class LoggingVMMController(VMMController): def start(self, vm): - logger.debug("Starting \"{vm}\"".format(vm=vm)) + logger.debug('Starting "{vm}"'.format(vm=vm)) super().start(vm) def poweroff(self, vm): - logger.debug("Turning off \"{vm}\"".format(vm=vm)) + logger.debug('Turning off "{vm}"'.format(vm=vm)) super().poweroff(vm) def delete(self, vm): - logger.debug("Deleting \"{vm}\"".format(vm=vm)) + logger.debug('Deleting "{vm}"'.format(vm=vm)) super().delete(vm) def set_mac(self, vm, mac, if_id=1): mac_string = hex(mac)[2:].rjust(12, "0") - logger.debug("Setting MAC address of \"{vm}\" interface {if_id} to {mac}" - .format(vm=vm, if_id=if_id, mac=mac_string)) + logger.debug( + 'Setting MAC address of "{vm}" interface {if_id} to {mac}'.format(vm=vm, if_id=if_id, mac=mac_string) + ) super().set_mac(vm, mac, if_id=if_id) def create_snapshot(self, vm, snapshot): - logger.debug("Creating snapshot \"{snapshot}\" for \"{vm}\"".format(vm=vm, snapshot=snapshot)) + logger.debug('Creating snapshot "{snapshot}" for "{vm}"'.format(vm=vm, snapshot=snapshot)) super().create_snapshot(vm, snapshot) def delete_snapshot(self, vm, snapshot): - logger.debug("Deleting snapshot \"{snapshot}\" of \"{vm}\"".format(vm=vm, snapshot=snapshot)) + logger.debug('Deleting snapshot "{snapshot}" of "{vm}"'.format(vm=vm, snapshot=snapshot)) super().delete_snapshot(vm, snapshot) def restore_snapshot(self, vm, snapshot): - logger.debug("Restoring snapshot \"{snapshot}\" of \"{vm}\"".format(vm=vm, snapshot=snapshot)) + logger.debug('Restoring snapshot "{snapshot}" of "{vm}"'.format(vm=vm, snapshot=snapshot)) super().restore_snapshot(vm, snapshot) def clone(self, vm, snapshot, clone): - logger.debug("Cloning \"{vm}\" with snapshot \"{snapshot}\" to \"{clone}\"" - .format(vm=vm, snapshot=snapshot, clone=clone)) + logger.debug( + 'Cloning "{vm}" with snapshot "{snapshot}" to "{clone}"'.format(vm=vm, snapshot=snapshot, clone=clone) + ) super().clone(vm, snapshot, clone) def set_credentials(self, vm, user, password, domain): - logger.debug("Setting credentials of \"{vm}\" to {cred}" - .format(vm=vm, cred=(user, password, domain))) + logger.debug('Setting credentials of "{vm}" to {cred}'.format(vm=vm, cred=(user, password, domain))) super().set_credentials(vm, user, password, domain)