Skip to content

Commit

Permalink
Merge pull request #475 from gbregman/devel
Browse files Browse the repository at this point in the history
Make sure not to include "_" in OMAP key components
  • Loading branch information
gbregman authored Mar 5, 2024
2 parents bfea178 + 0f01098 commit dfb7bdf
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 16 deletions.
21 changes: 19 additions & 2 deletions control/grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def __init__(self, config: GatewayConfig, gateway_state: GatewayStateHandler, om
self.gateway_name = self.config.get("gateway", "name")
if not self.gateway_name:
self.gateway_name = socket.gethostname()
if not GatewayState.is_key_element_valid(self.gateway_name):
raise ValueError(f"Gateway name \"{self.gateway_name}\" contains invalid characters")
self.gateway_group = self.config.get("gateway", "group")
self.verify_nqns = self.config.getboolean_with_default("gateway", "verify_nqns", True)
self.ana_map = defaultdict(dict)
Expand Down Expand Up @@ -458,6 +460,11 @@ def create_subsystem_safe(self, request, context):
f"Received request to create subsystem {request.subsystem_nqn}, enable_ha: {request.enable_ha}, context: {context}")

errmsg = ""
if not GatewayState.is_key_element_valid(request.subsystem_nqn):
errmsg = f"{create_subsystem_error_prefix}: Invalid NQN \"{request.subsystem_nqn}\", contains invalid characters"
self.logger.error(f"{errmsg}")
return pb2.req_status(status = errno.EINVAL, error_message = errmsg)

if self.verify_nqns:
rc = GatewayUtils.is_valid_nqn(request.subsystem_nqn)
if rc[0] != 0:
Expand Down Expand Up @@ -762,7 +769,7 @@ def set_ana_state_safe(self, ana_info: pb2.ana_info, context=None):
nqn = nas.nqn
if not self.get_subsystem_ha_status(nqn):
continue
prefix = f"{self.gateway_state.local.LISTENER_PREFIX}{nqn}{GatewayState.OMAP_KEY_DELIMITER}{self.gateway_name}{GatewayState.OMAP_KEY_DELIMITER}"
prefix = GatewayState.build_partial_listener_key(nqn, self.gateway_name) + GatewayState.OMAP_KEY_DELIMITER
listener_keys = [key for key in state.keys() if key.startswith(prefix)]
self.logger.info(f"Iterate over {nqn=} {prefix=} {listener_keys=}")
# fill the static gateway dictionary per nqn and grp_id
Expand Down Expand Up @@ -1563,6 +1570,16 @@ def add_host_safe(self, request, context):
all_host_failure_prefix=f"Failure allowing open host access to {request.subsystem_nqn}"
host_failure_prefix=f"Failure adding host {request.host_nqn} to {request.subsystem_nqn}"

if not GatewayState.is_key_element_valid(request.host_nqn):
errmsg = f"{host_failure_prefix}: Invalid host NQN \"{request.host_nqn}\", contains invalid characters"
self.logger.error(f"{errmsg}")
return pb2.req_status(status = errno.EINVAL, error_message = errmsg)

if not GatewayState.is_key_element_valid(request.subsystem_nqn):
errmsg = f"{host_failure_prefix}: Invalid subsystem NQN \"{request.subsystem_nqn}\", contains invalid characters"
self.logger.error(f"{errmsg}")
return pb2.req_status(status = errno.EINVAL, error_message = errmsg)

if self.verify_nqns:
rc = GatewayService.is_valid_host_nqn(request.host_nqn)
if rc.status != 0:
Expand Down Expand Up @@ -1929,7 +1946,7 @@ def matching_listener_exists(self, context, nqn, gw_name, traddr, trsvcid) -> bo
state = self.gateway_state.local.get_state()
# We want to check for all the listeners for this address and port, regardless of the gateway
key_prefix = GatewayState.build_partial_listener_key(nqn)
key_suffix = GatewayState.build_listener_key_suffix("", "TCP", traddr, trsvcid)
key_suffix = GatewayState.build_listener_key_suffix(None, "TCP", traddr, trsvcid)

for key, val in state.items():
if not key.startswith(key_prefix):
Expand Down
22 changes: 16 additions & 6 deletions control/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,23 @@ class GatewayState(ABC):
LISTENER_PREFIX = "listener" + OMAP_KEY_DELIMITER
NAMESPACE_QOS_PREFIX = "qos" + OMAP_KEY_DELIMITER

def is_key_element_valid(s: str) -> bool:
if type(s) != str:
return False
if GatewayState.OMAP_KEY_DELIMITER in s:
return False
return True

def build_namespace_key(subsystem_nqn: str, nsid) -> str:
key = GatewayState.NAMESPACE_PREFIX + subsystem_nqn
if nsid is not None:
key = key + GatewayState.OMAP_KEY_DELIMITER + str(nsid)
key += GatewayState.OMAP_KEY_DELIMITER + str(nsid)
return key

def build_namespace_qos_key(subsystem_nqn: str, nsid) -> str:
key = GatewayState.NAMESPACE_QOS_PREFIX + subsystem_nqn
if nsid is not None:
key = key + GatewayState.OMAP_KEY_DELIMITER + str(nsid)
key += GatewayState.OMAP_KEY_DELIMITER + str(nsid)
return key

def build_subsystem_key(subsystem_nqn: str) -> str:
Expand All @@ -49,11 +56,14 @@ def build_subsystem_key(subsystem_nqn: str) -> str:
def build_host_key(subsystem_nqn: str, host_nqn) -> str:
key = GatewayState.HOST_PREFIX + subsystem_nqn
if host_nqn is not None:
key = key + GatewayState.OMAP_KEY_DELIMITER + host_nqn
key += GatewayState.OMAP_KEY_DELIMITER + host_nqn
return key

def build_partial_listener_key(subsystem_nqn: str) -> str:
return GatewayState.LISTENER_PREFIX + subsystem_nqn
def build_partial_listener_key(subsystem_nqn: str, gateway = None) -> str:
key = GatewayState.LISTENER_PREFIX + subsystem_nqn
if gateway:
key += GatewayState.OMAP_KEY_DELIMITER + gateway
return key

def build_listener_key_suffix(gateway: str, trtype: str, traddr: str, trsvcid: int) -> str:
if gateway:
Expand All @@ -63,7 +73,7 @@ def build_listener_key_suffix(gateway: str, trtype: str, traddr: str, trsvcid: i
return GatewayState.OMAP_KEY_DELIMITER + traddr + GatewayState.OMAP_KEY_DELIMITER + str(trsvcid)

def build_listener_key(subsystem_nqn: str, gateway: str, trtype: str, traddr: str, trsvcid: int) -> str:
return GatewayState.build_partial_listener_key(subsystem_nqn) + GatewayState.build_listener_key_suffix(gateway, trtype, traddr, str(trsvcid))
return GatewayState.build_partial_listener_key(subsystem_nqn, gateway) + GatewayState.build_listener_key_suffix(None, trtype, traddr, str(trsvcid))

@abstractmethod
def get_state(self) -> Dict[str, str]:
Expand Down
2 changes: 1 addition & 1 deletion control/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def __init__(self, config=None):
log_files_rotation_enabled = False
max_log_file_size = GatewayLogger.MAX_LOG_FILE_SIZE_DEFAULT
max_log_files_count = GatewayLogger.MAX_LOG_FILES_COUNT_DEFAULT
log_leGatewayLoggervel = "info"
log_level = "info"

self.handler = None
logdir_ok = False
Expand Down
32 changes: 26 additions & 6 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ def gateway(config):

addr = config.get("gateway", "addr")
port = config.getint("gateway", "port")
config.config["gateway"]["log_level"] = "debug"

with GatewayServer(config) as gateway:

# Start gateway
gateway.gw_logger_object.set_log_level("debug")
gateway.set_group_id(0)
gateway.serve()

Expand Down Expand Up @@ -156,6 +158,10 @@ def test_create_subsystem(self, caplog, gateway):
cli(["subsystem", "add", "--subsystem", "nqn.2016-06.io.-spdk:cnode1"])
assert f"reverse domain is not formatted correctly" in caplog.text
caplog.clear()
cli(["subsystem", "add", "--subsystem", f"{subsystem}_X"])
assert f"Invalid NQN" in caplog.text
assert f"contains invalid characters" in caplog.text
caplog.clear()
cli(["subsystem", "add", "--subsystem", subsystem, "--max-namespaces", "2049"])
assert f"create_subsystem {subsystem}: True" in caplog.text
cli(["--format", "json", "subsystem", "list"])
Expand Down Expand Up @@ -554,6 +560,14 @@ def test_add_host_invalid_nqn(self, caplog):
caplog.clear()
cli(["host", "add", "--subsystem", subsystem, "--host", "nqn.2X16-06.io.spdk:host1"])
assert f"invalid date code" in caplog.text
caplog.clear()
cli(["host", "add", "--subsystem", subsystem, "--host", "nqn.2016-06.io.spdk:host1_X"])
assert f"Invalid host NQN" in caplog.text
assert f"contains invalid characters" in caplog.text
caplog.clear()
cli(["host", "add", "--subsystem", f"{subsystem}_X", "--host", "nqn.2016-06.io.spdk:host2"])
assert f"Invalid subsystem NQN" in caplog.text
assert f"contains invalid characters" in caplog.text

@pytest.mark.parametrize("listener", listener_list)
def test_create_listener(self, caplog, listener, gateway):
Expand Down Expand Up @@ -686,9 +700,12 @@ def test_remove_namespace(self, caplog, gateway):
bdev_found = False
bdev_list = rpc_bdev.bdev_get_bdevs(gw.spdk_rpc_client)
for b in bdev_list:
if bdev_name == b["name"]:
bdev_found = True
break
try:
if bdev_name == b["name"]:
bdev_found = True
break
except KeyError:
print(f"Couldn't find field name in: {b}")
assert bdev_found
caplog.clear()
del_ns_req = pb2.namespace_delete_req(subsystem_nqn=subsystem)
Expand All @@ -701,9 +718,12 @@ def test_remove_namespace(self, caplog, gateway):
bdev_found = False
bdev_list = rpc_bdev.bdev_get_bdevs(gw.spdk_rpc_client)
for b in bdev_list:
if bdev_name == b["name"]:
bdev_found = True
break
try:
if bdev_name == b["name"]:
bdev_found = True
break
except KeyError:
print(f"Couldn't find field name in: {b}")
assert not bdev_found
caplog.clear()
cli(["namespace", "del", "--subsystem", subsystem, "--nsid", "2"])
Expand Down
2 changes: 1 addition & 1 deletion tests/test_log_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def gateway(config, request):
config.config["gateway"]["log_files_enabled"] = "True"
config.config["gateway"]["max_log_file_size_in_mb"] = "10"
config.config["gateway"]["log_files_rotation_enabled"] = "True"
config.config["gateway"]["name"] = request.node.name
config.config["gateway"]["name"] = request.node.name.replace("_", "-")
if request.node.name == "test_log_files_disabled":
config.config["gateway"]["log_files_enabled"] = "False"
elif request.node.name == "test_log_files_rotation":
Expand Down

0 comments on commit dfb7bdf

Please sign in to comment.