diff --git a/api/filtering/db_custom_filters.py b/api/filtering/db_custom_filters.py index 5affff2ace..25ca0f4231 100644 --- a/api/filtering/db_custom_filters.py +++ b/api/filtering/db_custom_filters.py @@ -2,6 +2,7 @@ from sqlalchemy import Integer from sqlalchemy import and_ +from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy.sql.expression import ColumnElement from sqlalchemy.sql.expression import ColumnOperators @@ -116,8 +117,13 @@ def separate_operating_system_filters(filter_param) -> list[OsComparison]: return [OsComparison(comparator=filter_param)] # filter_param is a dict - for os_name in filter_param.keys(): # this doesn't account for "os_name" instead of os names - if os_name not in (os_names := _get_valid_os_names()): + for os_name in filter_param.keys(): + if os_name == "name": + ((comparator, real_os_name),) = filter_param["name"].items() + # remember to raise validation exception if name is wrong + return [OsComparison(real_os_name, comparator, None)] + + elif os_name not in (os_names := _get_valid_os_names()): raise ValidationException(f"operating_system filter only supports these OS names: {os_names}.") if not isinstance(version_node := filter_param[os_name]["version"], dict): @@ -149,16 +155,6 @@ def build_operating_system_filter(filter_param: dict) -> tuple: os_range_filter_list = [] # Contains the OS filters that use range operations os_field = Host.system_profile_facts["operating_system"] - # if isinstance(filter_param["operating_system"], dict) and "name" in filter_param["operating_system"].keys(): - # for os in filter_param["operating_system"]["name"].values(): - # print(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>> found name", os_field["name"].astext == os) - # # os_filter_list.append(and_([os_field["name"].astext == os])) - - # return os_filter_list - - if "name" in filter_param.keys(): - os_filter_list.append() - separated_filters = separate_operating_system_filters(filter_param["operating_system"]) for comparison in separated_filters: @@ -168,13 +164,14 @@ def build_operating_system_filter(filter_param: dict) -> tuple: # Uses the comparator with None, resulting in either is_(None) or is_not(None) os_filter_list.append(os_field.astext.operate(comparator, None)) - elif comparison.comparator == "eq": - print("~~~~~~~~~~~~~~~~~~~``", os_field["name"].astext == comparison.name) + elif comparison.comparator in ["eq", "neq"]: os_filters = [ - os_field["name"].astext == comparison.name, - os_field["major"].astext.cast(Integer) == comparison.major, + func.lower(os_field["name"].astext).operate(comparator, comparison.name.lower()), ] + if comparison.major is not None: + os_filters.append(os_field["major"].astext.cast(Integer) == comparison.major) + if comparison.minor: os_filters.append(os_field["minor"].astext.cast(Integer) == comparison.minor) diff --git a/tests/test_api_hosts_get.py b/tests/test_api_hosts_get.py index 0052d8ddfd..13581deeb0 100644 --- a/tests/test_api_hosts_get.py +++ b/tests/test_api_hosts_get.py @@ -1396,14 +1396,17 @@ def test_query_all_sp_filters_operating_system(db_create_host, api_get, sp_filte @pytest.mark.parametrize( - "sp_filter_param", + "sp_filter_param,match", ( - "[name][eq]=CentOS", - "[name][eq]=centos", - "[name][eq]=CENTOS", + ("[name][eq]=CentOS", True), + ("[name][eq]=centos", True), + ("[name][eq]=CENTOS", True), + ("[name][eq]=centos&filter[system_profile][operating_system][RHEL][version][eq][]=8", True), + ("[name][neq]=CENTOS", False), + ("[name][neq]=CentOS", False), ), ) -def test_query_sp_filters_operating_system_name(db_create_host, api_get, sp_filter_param): +def test_query_sp_filters_operating_system_name(db_create_host, api_get, sp_filter_param, match): # Create host with this OS match_sp_data = { "system_profile_facts": { @@ -1437,8 +1440,13 @@ def test_query_sp_filters_operating_system_name(db_create_host, api_get, sp_filt # Assert that only the matching host is returned response_ids = [result["id"] for result in response_data["results"]] - assert match_host_id in response_ids - assert nomatch_host_id not in response_ids + + if match: + assert match_host_id in response_ids + assert nomatch_host_id not in response_ids + else: + assert nomatch_host_id in response_ids + assert match_host_id not in response_ids @pytest.mark.parametrize(