diff --git a/annet/adapters/netbox/common/query.py b/annet/adapters/netbox/common/query.py index 7a9c2c3c..5b8d8f9d 100644 --- a/annet/adapters/netbox/common/query.py +++ b/annet/adapters/netbox/common/query.py @@ -1,8 +1,19 @@ +from collections import defaultdict from dataclasses import dataclass -from typing import List, Union, Iterable, Optional +from typing import cast, List, Union, Iterable, Optional, TypedDict from annet.storage import Query +FIELD_VALUE_SEPARATOR = ":" +ALLOWED_GLOB_GROUPS = ["site", "tag", "role"] + + +class Filter(TypedDict, total=False): + site: list[str] + tag: list[str] + role: list[str] + name: list[str] + @dataclass class NetboxQuery(Query): @@ -22,5 +33,29 @@ def globs(self): # We process every query host as a glob return self.query + def parse_query(self) -> Filter: + query_groups = defaultdict(list) + for q in self.globs: + if FIELD_VALUE_SEPARATOR in q: + glob_type, param = q.split(FIELD_VALUE_SEPARATOR, 2) + if glob_type not in ALLOWED_GLOB_GROUPS: + raise Exception(f"unknown query type: '{glob_type}'") + if not param: + raise Exception(f"empty param for '{glob_type}'") + query_groups[glob_type].append(param) + else: + query_groups["name"].append(q) + + query_groups.default_factory = None + return cast(Filter, query_groups) + def is_empty(self) -> bool: return len(self.query) == 0 + + def is_host_query(self) -> bool: + if not self.globs: + return False + for q in self.globs: + if FIELD_VALUE_SEPARATOR in q: + return False + return True diff --git a/annet/adapters/netbox/v37/storage.py b/annet/adapters/netbox/v37/storage.py index 6523fa79..a29be988 100644 --- a/annet/adapters/netbox/v37/storage.py +++ b/annet/adapters/netbox/v37/storage.py @@ -1,8 +1,8 @@ -from logging import getLogger -from typing import Any, Optional, List, Union, Dict -from ipaddress import ip_interface -from collections import defaultdict import ssl +from collections import defaultdict +from ipaddress import ip_interface +from logging import getLogger +from typing import Any, Optional, List, Union, Dict, cast from adaptix import P from adaptix.conversion import impl_converter, link, link_constant @@ -13,7 +13,7 @@ from annet.adapters.netbox.common.manufacturer import ( get_hw, get_breed, ) -from annet.adapters.netbox.common.query import NetboxQuery +from annet.adapters.netbox.common.query import NetboxQuery, FIELD_VALUE_SEPARATOR from annet.adapters.netbox.common.storage_opts import NetboxStorageOpts from annet.annlib.netdev.views.hardware import HardwareView from annet.storage import Storage, Device, Interface @@ -101,6 +101,9 @@ def __init__(self, opts: Optional[NetboxStorageOpts] = None): self.exact_host_filter = opts.exact_host_filter self.netbox = NetboxV37(url=url, token=token, ssl_context=ctx) self._all_fqdns: Optional[list[str]] = None + self._id_devices: dict[int, models.NetboxDevice] = {} + self._name_devices: dict[str, models.NetboxDevice] = {} + self._short_name_devices: dict[str, models.NetboxDevice] = {} def __enter__(self): return self @@ -136,6 +139,37 @@ def make_devices( ) -> List[models.NetboxDevice]: if isinstance(query, list): query = NetboxQuery.new(query) + + devices = [] + if query.is_host_query(): + globs = [] + for glob in query.globs: + if glob in self._name_devices: + devices.append(self._name_devices[glob]) + if glob in self._short_name_devices: + devices.append(self._short_name_devices[glob]) + else: + globs.append(glob) + if not globs: + return devices + query = NetboxQuery.new(globs) + + return devices + self._make_devices( + query=query, + preload_neighbors=preload_neighbors, + use_mesh=use_mesh, + preload_extra_fields=preload_extra_fields, + **kwargs + ) + + def _make_devices( + self, + query: NetboxQuery, + preload_neighbors=False, + use_mesh=None, + preload_extra_fields=False, + **kwargs, + ) -> List[models.NetboxDevice]: device_ids = { device.id: extend_device( device=device, @@ -148,6 +182,9 @@ def make_devices( if not device_ids: return [] + for device in device_ids.values(): + self._record_device(device) + interfaces = self._load_interfaces(list(device_ids)) neighbours = {x.id: x for x in self._load_neighbours(interfaces)} neighbours_seen: dict[str, set] = defaultdict(set) @@ -162,10 +199,17 @@ def make_devices( return list(device_ids.values()) + def _record_device(self, device: models.NetboxDevice): + self._id_devices[device.id] = device + self._short_name_devices[device.name] = device + if not self.exact_host_filter: + short_name = device.name.split(".")[0] + self._short_name_devices[short_name] = device + def _load_devices(self, query: NetboxQuery) -> List[api_models.Device]: if not query.globs: return [] - query_groups = parse_glob(self.exact_host_filter, query.globs) + query_groups = parse_glob(self.exact_host_filter, query) return [ device for device in self.netbox.dcim_all_devices(**query_groups).results @@ -221,6 +265,9 @@ def get_device( self, obj_id, preload_neighbors=False, use_mesh=None, **kwargs, ) -> models.NetboxDevice: + if obj_id in self._id_devices: + return self._id_devices[obj_id] + device = self.netbox.dcim_device(obj_id) interfaces = self._load_interfaces([device.id]) neighbours = self._load_neighbours(interfaces) @@ -231,6 +278,7 @@ def get_device( interfaces=interfaces, neighbours=neighbours, ) + self._record_device(res) return res def flush_perf(self): @@ -261,7 +309,7 @@ def _match_query(exact_host_filter: bool, query: NetboxQuery, device_data: api_m """ if exact_host_filter: return True # nothing to check, all filtering is done by netbox - hostnames = [subquery.strip() for subquery in query.globs if ":" not in subquery] + hostnames = [subquery.strip() for subquery in query.globs if FIELD_VALUE_SEPARATOR not in subquery] if not hostnames: return True # no hostnames to check short_name = device_data.name.split(".")[0] @@ -292,24 +340,11 @@ def add_dot(raw_query: Any) -> Any: return raw_query -ALLOWED_GLOB_GROUPS = ["site", "tag", "role"] - - -def parse_glob(exact_host_filter: bool, globs: list[str]) -> dict[str, list[str]]: - query_groups: defaultdict[str, list[str]] = defaultdict(list) - for q in globs: - if ":" in q: - glob_type, param = q.split(":", 2) - if glob_type not in ALLOWED_GLOB_GROUPS: - raise Exception(f"unknown query type: '{glob_type}'") - if not param: - raise Exception(f"empty param for '{glob_type}'") - query_groups[glob_type].append(param) +def parse_glob(exact_host_filter: bool, query: NetboxQuery) -> dict[str, list[str]]: + query_groups = cast(dict[str, list[str]], query.parse_query()) + if names := query_groups.pop("name", None): + if exact_host_filter: + query_groups["name__ie"] = names else: - if exact_host_filter: - query_groups["name__ie"].append(q) - else: - query_groups["name__ic"].append(_hostname_dot_hack(q)) - - query_groups.default_factory = None + query_groups["name__ic"] = [_hostname_dot_hack(name) for name in names] return query_groups diff --git a/tests/annet/test_netbox.py b/tests/annet/test_netbox.py index 73802f0e..27b3804f 100644 --- a/tests/annet/test_netbox.py +++ b/tests/annet/test_netbox.py @@ -1,13 +1,14 @@ +from annet.adapters.netbox.common.query import NetboxQuery from annet.adapters.netbox.v37.storage import parse_glob import pytest def test_parse_glob(): - assert parse_glob(True, ["host"]) == {"name__ie": ["host"]} - assert parse_glob(False, ["host"]) == {"name__ic": ["host."]} - assert parse_glob(True, ["site:mysite"]) == {"site": ["mysite"]} - assert parse_glob(True, ["tag:mysite", "justhost"]) == {"name__ie": ["justhost"], "tag": ["mysite"]} + assert parse_glob(True, NetboxQuery(["host"])) == {"name__ie": ["host"]} + assert parse_glob(False,NetboxQuery(["host"])) == {"name__ic": ["host."]} + assert parse_glob(True, NetboxQuery(["site:mysite"])) == {"site": ["mysite"]} + assert parse_glob(True, NetboxQuery(["tag:mysite", "justhost"])) == {"name__ie": ["justhost"], "tag": ["mysite"]} with pytest.raises(Exception): - parse_glob(True, ["host:"]) + parse_glob(True, NetboxQuery(["host:"])) with pytest.raises(Exception): - parse_glob(True, ["NONONO:param"]) + parse_glob(True, NetboxQuery(["NONONO:param"]))