Skip to content

Commit

Permalink
refactor OsComparison
Browse files Browse the repository at this point in the history
  • Loading branch information
FabriciaDinizRH committed Jan 10, 2025
1 parent 73d0fc2 commit 1adf118
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 54 deletions.
119 changes: 67 additions & 52 deletions api/filtering/db_custom_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from api.filtering.filtering_common import POSTGRES_COMPARATOR_LOOKUP
from api.filtering.filtering_common import POSTGRES_COMPARATOR_NO_EQ_LOOKUP
from api.filtering.filtering_common import POSTGRES_DEFAULT_COMPARATOR
from api.filtering.filtering_common import get_valid_os_names
from app import system_profile_spec
from app.config import HOST_TYPES
from app.exceptions import ValidationException
Expand All @@ -23,8 +24,27 @@

# Utility class to facilitate OS filter comparison
# The list of comparators can be seen in POSTGRES_COMPARATOR_LOOKUP
class OsComparison:
def __init__(self, name="", comparator="", major=0, minor=None):
class OsFilter:
def __init__(self, name="", comparator="", version=None):
if name and name.lower() not in (os_names := [name.lower() for name in get_valid_os_names()]):
raise ValidationException(f"operating_system filter only supports these OS names: {os_names}.")

if version is None:
major, minor = None, None
else:
version_split = version.split(".")

if len(version_split) > 2:
raise ValidationException("operating_system filter can only have a major and minor version.")
elif len(version_split) == 1: # only major version was sent
major = version_split[0]
minor = None
else:
major, minor = version_split

if not major.isdigit() or (minor and not minor.isdigit()):
raise ValidationException("operating_system major and minor versions must be numerical.")

self.name = name
self.comparator = comparator
self.major = major
Expand Down Expand Up @@ -94,60 +114,42 @@ def _get_field_filter_for_deepest_param(sp_spec: dict, filter: dict, parent_node
return sp_spec[key]["filter"]


def _get_valid_os_names() -> list:
return system_profile_spec()["operating_system"]["children"]["name"]["enum"]


# Extracts specific filters from the filter param object and puts them in an easier format
# For instance, {'RHEL': {'version': {'lt': '9.0', 'gt': '8.5'}}} becomes:
# [
# OsComparison{name: 'RHEL', comparator: 'lt', major: '9', minor: '0'}
# OsComparison{name: 'RHEL', comparator: 'gt', major: '8', minor: '5'}
# OsFilter{name: 'RHEL', major: '9', minor: '0', comparator: 'lt'}
# OsFilter{name: 'RHEL', major: '8', minor: '5', comparator: 'gt'}
# ]
# Has a similar purpose to _unique_paths, but the OS filter works a bit differently.
def separate_operating_system_filters(filter_param) -> list[OsComparison]:
def separate_operating_system_filters(filter_url_params) -> list[OsFilter]:
os_filter_list = []

# Handle filter_param if a list is passed in
if isinstance(filter_param, list):
return [OsComparison(comparator=param) for param in filter_param]
# Handle filter_url_params if a list is passed in
if isinstance(filter_url_params, list):
return [OsFilter(comparator=param) for param in filter_url_params]

# Handle filter_param if a str is passed in
elif isinstance(filter_param, str):
return [OsComparison(comparator=filter_param)]
# Handle filter_url_params if a str is passed in
elif isinstance(filter_url_params, str):
return [OsFilter(comparator=filter_url_params)]

# filter_param is a dict
for filter_key in filter_param.keys():
# filter_url_params is a dict
for filter_key in filter_url_params.keys():
if filter_key == "name":
((os_comparator, os_name),) = filter_param["name"].items()
((os_comparator, os_name),) = filter_url_params["name"].items()
version_node = {os_comparator: [None]}
else:
os_name = filter_key
if not isinstance(version_node := filter_param[os_name]["version"], dict):
if not isinstance(version_node := filter_url_params[os_name]["version"], dict):
# If there's no comparator, treat it as "eq"
version_node = {"eq": version_node}

if os_name.lower() not in (os_names := [name.lower() for name in _get_valid_os_names()]):
raise ValidationException(f"operating_system filter only supports these OS names: {os_names}.")

for os_comparator in version_node.keys():
version_array = version_node[os_comparator]
if not isinstance(version_array, list):
version_array = [version_array]

for version in version_array:
if version is None:
version_split = [None, None]
else:
version_split = version.split(".")
if len(version_split) > 2:
raise ValidationException("operating_system filter can only have a major and minor version.")

for v in version_split:
if not v.isdigit():
raise ValidationException("operating_system major and minor versions must be numerical.")

os_filter_list.append(OsComparison(os_name, os_comparator, *version_split))
os_filter_list.append(OsFilter(os_name, os_comparator, version))

return os_filter_list

Expand All @@ -160,46 +162,46 @@ def build_operating_system_filter(filter_param: dict) -> tuple:

separated_filters = separate_operating_system_filters(filter_param["operating_system"])

for comparison in separated_filters:
comparator = POSTGRES_COMPARATOR_LOOKUP.get(comparison.comparator)
for os_filter in separated_filters:
comparator = POSTGRES_COMPARATOR_LOOKUP.get(os_filter.comparator)

if comparison.comparator in ["nil", "not_nil"]:
if os_filter.comparator in ["nil", "not_nil"]:
# Uses the comparator with None, resulting in either is_(None) or is_not(None)
os_filter_list.append(os_field.astext.operate(comparator, None))

elif comparison.comparator in ["eq", "neq"]:
elif os_filter.comparator in ["eq", "neq"]:
os_filters = [
func.lower(os_field["name"].astext).operate(comparator, comparison.name.lower()),
func.lower(os_field["name"].astext).operate(comparator, os_filter.name.lower()),
]

if comparison.major is not None:
os_filters.append(os_field["major"].astext.cast(Integer) == comparison.major)
if os_filter.major is not None:
os_filters.append(os_field["major"].astext.cast(Integer) == os_filter.major)

if comparison.minor:
os_filters.append(os_field["minor"].astext.cast(Integer) == comparison.minor)
if os_filter.minor:
os_filters.append(os_field["minor"].astext.cast(Integer) == os_filter.minor)

os_filter_list.append(and_(*os_filters))
else:
if comparison.minor is not None:
# If the minor version is specified, the comparison logic is a bit more complex. For instance:
if os_filter.minor is not None:
# If the minor version is specified, the os_filter logic is a bit more complex. For instance:
# input: version <= 9.5
# output: (major < 9) OR (major = 9 AND minor <= 5)
comparator_no_eq = POSTGRES_COMPARATOR_NO_EQ_LOOKUP.get(comparison.comparator)
comparator_no_eq = POSTGRES_COMPARATOR_NO_EQ_LOOKUP.get(os_filter.comparator)
os_filter = and_(
os_field["name"].astext == comparison.name,
os_field["name"].astext == os_filter.name,
or_(
os_field["major"].astext.cast(Integer).operate(comparator_no_eq, comparison.major),
os_field["major"].astext.cast(Integer).operate(comparator_no_eq, os_filter.major),
and_(
os_field["major"].astext.cast(Integer) == comparison.major,
os_field["minor"].astext.cast(Integer).operate(comparator, comparison.minor),
os_field["major"].astext.cast(Integer) == os_filter.major,
os_field["minor"].astext.cast(Integer).operate(comparator, os_filter.minor),
),
),
)

else:
os_filter = and_(
os_field["name"].astext == comparison.name,
os_field["major"].astext.cast(Integer).operate(comparator, comparison.major),
os_field["name"].astext == os_filter.name,
os_field["major"].astext.cast(Integer).operate(comparator, os_filter.major),
)

# Add to AND filter
Expand Down Expand Up @@ -373,3 +375,16 @@ def build_system_profile_filter(system_profile_param: dict) -> tuple:
system_profile_filter += (filter,)

return system_profile_filter


def get_major_minor_from_version(version_split: list[str]):
if len(version_split) > 2:
raise ValidationException("operating_system filter can only have a major and minor version.")

if not [v.isdigit() for v in version_split]:
raise ValidationException("operating_system major and minor versions must be numerical.")

major = version_split.pop(0)
minor = version_split[0] if version_split else None

return major, minor
6 changes: 6 additions & 0 deletions api/filtering/filtering_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from sqlalchemy import Boolean
from sqlalchemy.sql.expression import ColumnOperators

from app import system_profile_spec

# Converts our filter param comparison operators into their SQL equivalents.
POSTGRES_COMPARATOR_LOOKUP = {
"lt": ColumnOperators.__lt__,
Expand Down Expand Up @@ -43,3 +45,7 @@
"integer": lambda v: int(v),
"boolean": lambda v: str.lower(v) == "true",
}


def get_valid_os_names() -> list:
return system_profile_spec()["operating_system"]["children"]["name"]["enum"]
4 changes: 2 additions & 2 deletions tests/test_api_hosts_get.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,8 +1629,8 @@ def test_query_all_sp_filters_invalid_value(api_get, sp_filter_param):
"sp_filter_param",
(
"[operating_system][foo][version]=8.1", # Invalid OS name
"[operating_system][name][eq]=rhelz", # Invalid OS name
"[operating_system][RHEL][version]=bar", # Invalid OS version
"[operating_system][name][eq]=rhelz", # Invalid OS name
"[operating_system][RHEL][version]=bar", # Invalid OS version
),
)
def test_query_all_sp_filters_invalid_operating_system(api_get, sp_filter_param):
Expand Down

0 comments on commit 1adf118

Please sign in to comment.