From 1adf118411cd1d36adfb43bf27680b0d46abdc51 Mon Sep 17 00:00:00 2001 From: Fabricia Diniz Date: Fri, 10 Jan 2025 16:29:08 +0000 Subject: [PATCH] refactor OsComparison --- api/filtering/db_custom_filters.py | 119 ++++++++++++++++------------- api/filtering/filtering_common.py | 6 ++ tests/test_api_hosts_get.py | 4 +- 3 files changed, 75 insertions(+), 54 deletions(-) diff --git a/api/filtering/db_custom_filters.py b/api/filtering/db_custom_filters.py index 1223fe3ef..3571471ee 100644 --- a/api/filtering/db_custom_filters.py +++ b/api/filtering/db_custom_filters.py @@ -12,6 +12,7 @@ from api.filtering.filtering_common import POSTGRES_COMPARATOR_LOOKUP from api.filtering.filtering_common import POSTGRES_COMPARATOR_NO_EQ_LOOKUP from api.filtering.filtering_common import POSTGRES_DEFAULT_COMPARATOR +from api.filtering.filtering_common import get_valid_os_names from app import system_profile_spec from app.config import HOST_TYPES from app.exceptions import ValidationException @@ -23,8 +24,27 @@ # Utility class to facilitate OS filter comparison # The list of comparators can be seen in POSTGRES_COMPARATOR_LOOKUP -class OsComparison: - def __init__(self, name="", comparator="", major=0, minor=None): +class OsFilter: + def __init__(self, name="", comparator="", version=None): + if name and name.lower() not in (os_names := [name.lower() for name in get_valid_os_names()]): + raise ValidationException(f"operating_system filter only supports these OS names: {os_names}.") + + if version is None: + major, minor = None, None + else: + version_split = version.split(".") + + if len(version_split) > 2: + raise ValidationException("operating_system filter can only have a major and minor version.") + elif len(version_split) == 1: # only major version was sent + major = version_split[0] + minor = None + else: + major, minor = version_split + + if not major.isdigit() or (minor and not minor.isdigit()): + raise ValidationException("operating_system major and minor versions must be numerical.") + self.name = name self.comparator = comparator self.major = major @@ -94,60 +114,42 @@ def _get_field_filter_for_deepest_param(sp_spec: dict, filter: dict, parent_node return sp_spec[key]["filter"] -def _get_valid_os_names() -> list: - return system_profile_spec()["operating_system"]["children"]["name"]["enum"] - - # Extracts specific filters from the filter param object and puts them in an easier format # For instance, {'RHEL': {'version': {'lt': '9.0', 'gt': '8.5'}}} becomes: # [ -# OsComparison{name: 'RHEL', comparator: 'lt', major: '9', minor: '0'} -# OsComparison{name: 'RHEL', comparator: 'gt', major: '8', minor: '5'} +# OsFilter{name: 'RHEL', major: '9', minor: '0', comparator: 'lt'} +# OsFilter{name: 'RHEL', major: '8', minor: '5', comparator: 'gt'} # ] # Has a similar purpose to _unique_paths, but the OS filter works a bit differently. -def separate_operating_system_filters(filter_param) -> list[OsComparison]: +def separate_operating_system_filters(filter_url_params) -> list[OsFilter]: os_filter_list = [] - # Handle filter_param if a list is passed in - if isinstance(filter_param, list): - return [OsComparison(comparator=param) for param in filter_param] + # Handle filter_url_params if a list is passed in + if isinstance(filter_url_params, list): + return [OsFilter(comparator=param) for param in filter_url_params] - # Handle filter_param if a str is passed in - elif isinstance(filter_param, str): - return [OsComparison(comparator=filter_param)] + # Handle filter_url_params if a str is passed in + elif isinstance(filter_url_params, str): + return [OsFilter(comparator=filter_url_params)] - # filter_param is a dict - for filter_key in filter_param.keys(): + # filter_url_params is a dict + for filter_key in filter_url_params.keys(): if filter_key == "name": - ((os_comparator, os_name),) = filter_param["name"].items() + ((os_comparator, os_name),) = filter_url_params["name"].items() version_node = {os_comparator: [None]} else: os_name = filter_key - if not isinstance(version_node := filter_param[os_name]["version"], dict): + if not isinstance(version_node := filter_url_params[os_name]["version"], dict): # If there's no comparator, treat it as "eq" version_node = {"eq": version_node} - if os_name.lower() not in (os_names := [name.lower() for name in _get_valid_os_names()]): - raise ValidationException(f"operating_system filter only supports these OS names: {os_names}.") - for os_comparator in version_node.keys(): version_array = version_node[os_comparator] if not isinstance(version_array, list): version_array = [version_array] for version in version_array: - if version is None: - version_split = [None, None] - else: - version_split = version.split(".") - if len(version_split) > 2: - raise ValidationException("operating_system filter can only have a major and minor version.") - - for v in version_split: - if not v.isdigit(): - raise ValidationException("operating_system major and minor versions must be numerical.") - - os_filter_list.append(OsComparison(os_name, os_comparator, *version_split)) + os_filter_list.append(OsFilter(os_name, os_comparator, version)) return os_filter_list @@ -160,46 +162,46 @@ def build_operating_system_filter(filter_param: dict) -> tuple: separated_filters = separate_operating_system_filters(filter_param["operating_system"]) - for comparison in separated_filters: - comparator = POSTGRES_COMPARATOR_LOOKUP.get(comparison.comparator) + for os_filter in separated_filters: + comparator = POSTGRES_COMPARATOR_LOOKUP.get(os_filter.comparator) - if comparison.comparator in ["nil", "not_nil"]: + if os_filter.comparator in ["nil", "not_nil"]: # Uses the comparator with None, resulting in either is_(None) or is_not(None) os_filter_list.append(os_field.astext.operate(comparator, None)) - elif comparison.comparator in ["eq", "neq"]: + elif os_filter.comparator in ["eq", "neq"]: os_filters = [ - func.lower(os_field["name"].astext).operate(comparator, comparison.name.lower()), + func.lower(os_field["name"].astext).operate(comparator, os_filter.name.lower()), ] - if comparison.major is not None: - os_filters.append(os_field["major"].astext.cast(Integer) == comparison.major) + if os_filter.major is not None: + os_filters.append(os_field["major"].astext.cast(Integer) == os_filter.major) - if comparison.minor: - os_filters.append(os_field["minor"].astext.cast(Integer) == comparison.minor) + if os_filter.minor: + os_filters.append(os_field["minor"].astext.cast(Integer) == os_filter.minor) os_filter_list.append(and_(*os_filters)) else: - if comparison.minor is not None: - # If the minor version is specified, the comparison logic is a bit more complex. For instance: + if os_filter.minor is not None: + # If the minor version is specified, the os_filter logic is a bit more complex. For instance: # input: version <= 9.5 # output: (major < 9) OR (major = 9 AND minor <= 5) - comparator_no_eq = POSTGRES_COMPARATOR_NO_EQ_LOOKUP.get(comparison.comparator) + comparator_no_eq = POSTGRES_COMPARATOR_NO_EQ_LOOKUP.get(os_filter.comparator) os_filter = and_( - os_field["name"].astext == comparison.name, + os_field["name"].astext == os_filter.name, or_( - os_field["major"].astext.cast(Integer).operate(comparator_no_eq, comparison.major), + os_field["major"].astext.cast(Integer).operate(comparator_no_eq, os_filter.major), and_( - os_field["major"].astext.cast(Integer) == comparison.major, - os_field["minor"].astext.cast(Integer).operate(comparator, comparison.minor), + os_field["major"].astext.cast(Integer) == os_filter.major, + os_field["minor"].astext.cast(Integer).operate(comparator, os_filter.minor), ), ), ) else: os_filter = and_( - os_field["name"].astext == comparison.name, - os_field["major"].astext.cast(Integer).operate(comparator, comparison.major), + os_field["name"].astext == os_filter.name, + os_field["major"].astext.cast(Integer).operate(comparator, os_filter.major), ) # Add to AND filter @@ -373,3 +375,16 @@ def build_system_profile_filter(system_profile_param: dict) -> tuple: system_profile_filter += (filter,) return system_profile_filter + + +def get_major_minor_from_version(version_split: list[str]): + if len(version_split) > 2: + raise ValidationException("operating_system filter can only have a major and minor version.") + + if not [v.isdigit() for v in version_split]: + raise ValidationException("operating_system major and minor versions must be numerical.") + + major = version_split.pop(0) + minor = version_split[0] if version_split else None + + return major, minor diff --git a/api/filtering/filtering_common.py b/api/filtering/filtering_common.py index deeb67890..4eca438f0 100644 --- a/api/filtering/filtering_common.py +++ b/api/filtering/filtering_common.py @@ -6,6 +6,8 @@ from sqlalchemy import Boolean from sqlalchemy.sql.expression import ColumnOperators +from app import system_profile_spec + # Converts our filter param comparison operators into their SQL equivalents. POSTGRES_COMPARATOR_LOOKUP = { "lt": ColumnOperators.__lt__, @@ -43,3 +45,7 @@ "integer": lambda v: int(v), "boolean": lambda v: str.lower(v) == "true", } + + +def get_valid_os_names() -> list: + return system_profile_spec()["operating_system"]["children"]["name"]["enum"] diff --git a/tests/test_api_hosts_get.py b/tests/test_api_hosts_get.py index 3d9143c45..fb45d3c62 100644 --- a/tests/test_api_hosts_get.py +++ b/tests/test_api_hosts_get.py @@ -1629,8 +1629,8 @@ def test_query_all_sp_filters_invalid_value(api_get, sp_filter_param): "sp_filter_param", ( "[operating_system][foo][version]=8.1", # Invalid OS name - "[operating_system][name][eq]=rhelz", # Invalid OS name - "[operating_system][RHEL][version]=bar", # Invalid OS version + "[operating_system][name][eq]=rhelz", # Invalid OS name + "[operating_system][RHEL][version]=bar", # Invalid OS version ), ) def test_query_all_sp_filters_invalid_operating_system(api_get, sp_filter_param):