diff --git a/control/grpc.py b/control/grpc.py index 7d93abee..c5c61ceb 100644 --- a/control/grpc.py +++ b/control/grpc.py @@ -153,6 +153,8 @@ def __init__(self, config: GatewayConfig, gateway_state: GatewayStateHandler, om self.gateway_name = self.config.get("gateway", "name") if not self.gateway_name: self.gateway_name = socket.gethostname() + if not GatewayState.is_key_element_valid(self.gateway_name): + raise ValueError(f"Gateway name \"{self.gateway_name}\" contains invalid characters") self.gateway_group = self.config.get("gateway", "group") self.verify_nqns = self.config.getboolean_with_default("gateway", "verify_nqns", True) self.ana_map = defaultdict(dict) @@ -458,6 +460,11 @@ def create_subsystem_safe(self, request, context): f"Received request to create subsystem {request.subsystem_nqn}, enable_ha: {request.enable_ha}, context: {context}") errmsg = "" + if not GatewayState.is_key_element_valid(request.subsystem_nqn): + errmsg = f"{create_subsystem_error_prefix}: Invalid NQN \"{request.subsystem_nqn}\", contains invalid characters" + self.logger.error(f"{errmsg}") + return pb2.req_status(status = errno.EINVAL, error_message = errmsg) + if self.verify_nqns: rc = GatewayUtils.is_valid_nqn(request.subsystem_nqn) if rc[0] != 0: @@ -762,7 +769,7 @@ def set_ana_state_safe(self, ana_info: pb2.ana_info, context=None): nqn = nas.nqn if not self.get_subsystem_ha_status(nqn): continue - prefix = f"{self.gateway_state.local.LISTENER_PREFIX}{nqn}{GatewayState.OMAP_KEY_DELIMITER}{self.gateway_name}{GatewayState.OMAP_KEY_DELIMITER}" + prefix = GatewayState.build_partial_listener_key(nqn, self.gateway_name) + GatewayState.OMAP_KEY_DELIMITER listener_keys = [key for key in state.keys() if key.startswith(prefix)] self.logger.info(f"Iterate over {nqn=} {prefix=} {listener_keys=}") # fill the static gateway dictionary per nqn and grp_id @@ -1563,6 +1570,16 @@ def add_host_safe(self, request, context): all_host_failure_prefix=f"Failure allowing open host access to {request.subsystem_nqn}" host_failure_prefix=f"Failure adding host {request.host_nqn} to {request.subsystem_nqn}" + if not GatewayState.is_key_element_valid(request.host_nqn): + errmsg = f"{host_failure_prefix}: Invalid host NQN \"{request.host_nqn}\", contains invalid characters" + self.logger.error(f"{errmsg}") + return pb2.req_status(status = errno.EINVAL, error_message = errmsg) + + if not GatewayState.is_key_element_valid(request.subsystem_nqn): + errmsg = f"{host_failure_prefix}: Invalid subsystem NQN \"{request.subsystem_nqn}\", contains invalid characters" + self.logger.error(f"{errmsg}") + return pb2.req_status(status = errno.EINVAL, error_message = errmsg) + if self.verify_nqns: rc = GatewayService.is_valid_host_nqn(request.host_nqn) if rc.status != 0: @@ -1929,7 +1946,7 @@ def matching_listener_exists(self, context, nqn, gw_name, traddr, trsvcid) -> bo state = self.gateway_state.local.get_state() # We want to check for all the listeners for this address and port, regardless of the gateway key_prefix = GatewayState.build_partial_listener_key(nqn) - key_suffix = GatewayState.build_listener_key_suffix("", "TCP", traddr, trsvcid) + key_suffix = GatewayState.build_listener_key_suffix(None, "TCP", traddr, trsvcid) for key, val in state.items(): if not key.startswith(key_prefix): diff --git a/control/state.py b/control/state.py index 276dc680..a972a936 100644 --- a/control/state.py +++ b/control/state.py @@ -31,16 +31,23 @@ class GatewayState(ABC): LISTENER_PREFIX = "listener" + OMAP_KEY_DELIMITER NAMESPACE_QOS_PREFIX = "qos" + OMAP_KEY_DELIMITER + def is_key_element_valid(s: str) -> bool: + if type(s) != str: + return False + if GatewayState.OMAP_KEY_DELIMITER in s: + return False + return True + def build_namespace_key(subsystem_nqn: str, nsid) -> str: key = GatewayState.NAMESPACE_PREFIX + subsystem_nqn if nsid is not None: - key = key + GatewayState.OMAP_KEY_DELIMITER + str(nsid) + key += GatewayState.OMAP_KEY_DELIMITER + str(nsid) return key def build_namespace_qos_key(subsystem_nqn: str, nsid) -> str: key = GatewayState.NAMESPACE_QOS_PREFIX + subsystem_nqn if nsid is not None: - key = key + GatewayState.OMAP_KEY_DELIMITER + str(nsid) + key += GatewayState.OMAP_KEY_DELIMITER + str(nsid) return key def build_subsystem_key(subsystem_nqn: str) -> str: @@ -49,11 +56,14 @@ def build_subsystem_key(subsystem_nqn: str) -> str: def build_host_key(subsystem_nqn: str, host_nqn) -> str: key = GatewayState.HOST_PREFIX + subsystem_nqn if host_nqn is not None: - key = key + GatewayState.OMAP_KEY_DELIMITER + host_nqn + key += GatewayState.OMAP_KEY_DELIMITER + host_nqn return key - def build_partial_listener_key(subsystem_nqn: str) -> str: - return GatewayState.LISTENER_PREFIX + subsystem_nqn + def build_partial_listener_key(subsystem_nqn: str, gateway = None) -> str: + key = GatewayState.LISTENER_PREFIX + subsystem_nqn + if gateway: + key += GatewayState.OMAP_KEY_DELIMITER + gateway + return key def build_listener_key_suffix(gateway: str, trtype: str, traddr: str, trsvcid: int) -> str: if gateway: @@ -63,7 +73,7 @@ def build_listener_key_suffix(gateway: str, trtype: str, traddr: str, trsvcid: i return GatewayState.OMAP_KEY_DELIMITER + traddr + GatewayState.OMAP_KEY_DELIMITER + str(trsvcid) def build_listener_key(subsystem_nqn: str, gateway: str, trtype: str, traddr: str, trsvcid: int) -> str: - return GatewayState.build_partial_listener_key(subsystem_nqn) + GatewayState.build_listener_key_suffix(gateway, trtype, traddr, str(trsvcid)) + return GatewayState.build_partial_listener_key(subsystem_nqn, gateway) + GatewayState.build_listener_key_suffix(None, trtype, traddr, str(trsvcid)) @abstractmethod def get_state(self) -> Dict[str, str]: diff --git a/control/utils.py b/control/utils.py index 96e29a02..8cafb3ba 100644 --- a/control/utils.py +++ b/control/utils.py @@ -219,7 +219,7 @@ def __init__(self, config=None): log_files_rotation_enabled = False max_log_file_size = GatewayLogger.MAX_LOG_FILE_SIZE_DEFAULT max_log_files_count = GatewayLogger.MAX_LOG_FILES_COUNT_DEFAULT - log_leGatewayLoggervel = "info" + log_level = "info" self.handler = None logdir_ok = False diff --git a/tests/test_cli.py b/tests/test_cli.py index 14f65b84..8c07126a 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -43,10 +43,12 @@ def gateway(config): addr = config.get("gateway", "addr") port = config.getint("gateway", "port") + config.config["gateway"]["log_level"] = "debug" with GatewayServer(config) as gateway: # Start gateway + gateway.gw_logger_object.set_log_level("debug") gateway.set_group_id(0) gateway.serve() @@ -156,6 +158,10 @@ def test_create_subsystem(self, caplog, gateway): cli(["subsystem", "add", "--subsystem", "nqn.2016-06.io.-spdk:cnode1"]) assert f"reverse domain is not formatted correctly" in caplog.text caplog.clear() + cli(["subsystem", "add", "--subsystem", f"{subsystem}_X"]) + assert f"Invalid NQN" in caplog.text + assert f"contains invalid characters" in caplog.text + caplog.clear() cli(["subsystem", "add", "--subsystem", subsystem, "--max-namespaces", "2049"]) assert f"create_subsystem {subsystem}: True" in caplog.text cli(["--format", "json", "subsystem", "list"]) @@ -554,6 +560,14 @@ def test_add_host_invalid_nqn(self, caplog): caplog.clear() cli(["host", "add", "--subsystem", subsystem, "--host", "nqn.2X16-06.io.spdk:host1"]) assert f"invalid date code" in caplog.text + caplog.clear() + cli(["host", "add", "--subsystem", subsystem, "--host", "nqn.2016-06.io.spdk:host1_X"]) + assert f"Invalid host NQN" in caplog.text + assert f"contains invalid characters" in caplog.text + caplog.clear() + cli(["host", "add", "--subsystem", f"{subsystem}_X", "--host", "nqn.2016-06.io.spdk:host2"]) + assert f"Invalid subsystem NQN" in caplog.text + assert f"contains invalid characters" in caplog.text @pytest.mark.parametrize("listener", listener_list) def test_create_listener(self, caplog, listener, gateway): @@ -686,9 +700,12 @@ def test_remove_namespace(self, caplog, gateway): bdev_found = False bdev_list = rpc_bdev.bdev_get_bdevs(gw.spdk_rpc_client) for b in bdev_list: - if bdev_name == b["name"]: - bdev_found = True - break + try: + if bdev_name == b["name"]: + bdev_found = True + break + except KeyError: + print(f"Couldn't find field name in: {b}") assert bdev_found caplog.clear() del_ns_req = pb2.namespace_delete_req(subsystem_nqn=subsystem) @@ -701,9 +718,12 @@ def test_remove_namespace(self, caplog, gateway): bdev_found = False bdev_list = rpc_bdev.bdev_get_bdevs(gw.spdk_rpc_client) for b in bdev_list: - if bdev_name == b["name"]: - bdev_found = True - break + try: + if bdev_name == b["name"]: + bdev_found = True + break + except KeyError: + print(f"Couldn't find field name in: {b}") assert not bdev_found caplog.clear() cli(["namespace", "del", "--subsystem", subsystem, "--nsid", "2"]) diff --git a/tests/test_log_files.py b/tests/test_log_files.py index 1b8f0cb9..b060a145 100644 --- a/tests/test_log_files.py +++ b/tests/test_log_files.py @@ -35,7 +35,7 @@ def gateway(config, request): config.config["gateway"]["log_files_enabled"] = "True" config.config["gateway"]["max_log_file_size_in_mb"] = "10" config.config["gateway"]["log_files_rotation_enabled"] = "True" - config.config["gateway"]["name"] = request.node.name + config.config["gateway"]["name"] = request.node.name.replace("_", "-") if request.node.name == "test_log_files_disabled": config.config["gateway"]["log_files_enabled"] = "False" elif request.node.name == "test_log_files_rotation":