From 2d69f8b267377869f2286a2aae63b179c0a30716 Mon Sep 17 00:00:00 2001 From: "Josephine.Rutten" Date: Fri, 25 Oct 2024 15:33:09 +0200 Subject: [PATCH] add mypy to pre-commit and solved mypy errors --- .pre-commit-config.yaml | 18 +++ src/cnaas_nms/__init__.py | 2 +- src/cnaas_nms/api/app.py | 11 +- src/cnaas_nms/api/auth.py | 6 +- src/cnaas_nms/api/device.py | 73 ++++----- src/cnaas_nms/api/firmware.py | 30 ++-- src/cnaas_nms/api/generic.py | 11 +- src/cnaas_nms/api/groups.py | 6 +- src/cnaas_nms/api/interface.py | 17 +-- src/cnaas_nms/api/jobs.py | 19 +-- src/cnaas_nms/api/json.py | 2 +- src/cnaas_nms/api/linknet.py | 22 +-- src/cnaas_nms/api/mgmtdomain.py | 12 +- .../api/models/stackmembers_model.py | 6 +- src/cnaas_nms/api/settings.py | 2 +- src/cnaas_nms/api/tests/test_device.py | 16 +- src/cnaas_nms/app_settings.py | 8 +- src/cnaas_nms/db/device.py | 13 +- src/cnaas_nms/db/git.py | 2 +- src/cnaas_nms/db/helper.py | 122 ++++++++------- src/cnaas_nms/db/interface.py | 12 +- src/cnaas_nms/db/job.py | 55 ++++--- src/cnaas_nms/db/joblock.py | 34 ++--- src/cnaas_nms/db/linknet.py | 24 +-- src/cnaas_nms/db/mgmtdomain.py | 36 ++--- src/cnaas_nms/db/reservedip.py | 2 + src/cnaas_nms/db/session.py | 4 +- src/cnaas_nms/db/settings.py | 89 ++++++----- src/cnaas_nms/db/settings_fields.py | 29 ++-- src/cnaas_nms/db/stackmember.py | 14 +- src/cnaas_nms/db/tests/test_device.py | 12 +- src/cnaas_nms/db/tests/test_git.py | 17 ++- src/cnaas_nms/db/tests/test_mgmtdomain.py | 28 ++-- src/cnaas_nms/devicehandler/cert.py | 10 +- src/cnaas_nms/devicehandler/changescore.py | 9 +- src/cnaas_nms/devicehandler/erase.py | 10 +- src/cnaas_nms/devicehandler/firmware.py | 32 ++-- src/cnaas_nms/devicehandler/get.py | 21 ++- src/cnaas_nms/devicehandler/init_device.py | 112 +++++++------- .../devicehandler/interface_state.py | 2 +- src/cnaas_nms/devicehandler/nornir_helper.py | 6 +- .../nornir_plugins/cnaas_inventory.py | 2 +- src/cnaas_nms/devicehandler/sync_devices.py | 139 +++++++++--------- src/cnaas_nms/devicehandler/sync_history.py | 6 +- src/cnaas_nms/devicehandler/tests/test_get.py | 6 +- .../devicehandler/tests/test_init.py | 4 +- .../devicehandler/tests/test_syncto.py | 2 +- .../devicehandler/tests/test_update.py | 12 +- src/cnaas_nms/devicehandler/update.py | 24 +-- src/cnaas_nms/models/permissions.py | 6 +- src/cnaas_nms/plugins/nav.py | 8 +- src/cnaas_nms/plugins/ni.py | 18 ++- src/cnaas_nms/run.py | 10 +- src/cnaas_nms/scheduler/scheduler.py | 8 +- .../scheduler/tests/test_scheduler.py | 4 +- src/cnaas_nms/scheduler/wrapper.py | 14 +- src/cnaas_nms/scheduler_mule.py | 14 +- src/cnaas_nms/tools/cache.py | 4 +- src/cnaas_nms/tools/dhcp_hook.py | 4 +- src/cnaas_nms/tools/dropdb.py | 2 +- src/cnaas_nms/tools/event.py | 10 +- src/cnaas_nms/tools/initdb.py | 2 +- src/cnaas_nms/tools/jinja_filters.py | 13 +- src/cnaas_nms/tools/jinja_helpers.py | 4 +- src/cnaas_nms/tools/log.py | 24 +-- src/cnaas_nms/tools/oidc/key_management.py | 13 +- src/cnaas_nms/tools/oidc/oidc_client_call.py | 8 +- src/cnaas_nms/tools/oidc/token.py | 3 +- src/cnaas_nms/tools/pki.py | 10 +- src/cnaas_nms/tools/rbac/rbac.py | 13 +- src/cnaas_nms/tools/security.py | 7 +- src/cnaas_nms/tools/testsetup.py | 2 +- 72 files changed, 713 insertions(+), 639 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c02dbc55..afe880b5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,3 +24,21 @@ repos: - id: ruff args: [ --fix, --exit-non-zero-on-fix, --show-fixes ] exclude: (alembic/.*) + - repo: local + hooks: + - id: mypy + name: mypy + language: system + entry: "mypy" + types: [ python ] + require_serial: true + verbose: true + args: + - --no-warn-incomplete-stub + - --ignore-missing-imports + - --no-warn-unused-ignores + - --allow-untyped-decorators + - --no-namespace-packages + - --allow-untyped-globals + - --check-untyped-defs + exclude: "(alembic/.*)|(.*test.*)" diff --git a/src/cnaas_nms/__init__.py b/src/cnaas_nms/__init__.py index 387b2115..63d9b8f2 100644 --- a/src/cnaas_nms/__init__.py +++ b/src/cnaas_nms/__init__.py @@ -6,7 +6,7 @@ def setup_package(): from cnaas_nms.api.tests.app_wrapper import TestAppWrapper app = cnaas_nms.api.app.app - app.wsgi_app = TestAppWrapper(app.wsgi_app, None) + app.wsgi_app = TestAppWrapper(app.wsgi_app, None) # type: ignore client = app.test_client() data = {"action": "refresh"} client.put("/api/v1.0/repository/settings", json=data) diff --git a/src/cnaas_nms/api/app.py b/src/cnaas_nms/api/app.py index b36da21b..6df43c40 100644 --- a/src/cnaas_nms/api/app.py +++ b/src/cnaas_nms/api/app.py @@ -183,7 +183,8 @@ def socketio_on_connect(): if auth_settings.OIDC_ENABLED: try: token = oauth_required.get_token_validator("bearer").authenticate_token(token_string) - user = get_oauth_token_info(token)[auth_settings.OIDC_USERNAME_ATTRIBUTE] + token_info = get_oauth_token_info(token) + user = token_info[auth_settings.OIDC_USERNAME_ATTRIBUTE] except InvalidTokenError as e: logger.debug("InvalidTokenError: " + format(e)) return False @@ -229,21 +230,21 @@ def log_request(response): elif request.method in ["GET", "POST", "PUT", "DELETE", "PATCH"]: try: if auth_settings.OIDC_ENABLED: - token_string = request.headers.get("Authorization").split(" ")[-1] + token_string = str(request.headers.get("Authorization")).split(" ")[-1] token = oauth_required.get_token_validator("bearer").authenticate_token(token_string) token_info = get_oauth_token_info(token) - if auth_settings.OIDC_USERNAME_ATTRIBUTE in token_info: + if token_info is not None and auth_settings.OIDC_USERNAME_ATTRIBUTE in token_info: user = "User: {} ({}), ".format( get_oauth_token_info(token)[auth_settings.OIDC_USERNAME_ATTRIBUTE], auth_settings.OIDC_USERNAME_ATTRIBUTE, ) - elif "client_id" in token_info: + elif token_info is not None and "client_id" in token_info: user = "User: {} (client_id), ".format(get_oauth_token_info(token)["client_id"]) else: logger.warning("Could not get user info from token") raise ValueError else: - token_string = request.headers.get("Authorization").split(" ")[-1] + token_string = str(request.headers.get("Authorization")).split(" ")[-1] user = "User: {}, ".format(decode_token(token_string).get("sub")) except Exception: user = "User: unknown, " diff --git a/src/cnaas_nms/api/auth.py b/src/cnaas_nms/api/auth.py index fcb09437..1eb80b43 100644 --- a/src/cnaas_nms/api/auth.py +++ b/src/cnaas_nms/api/auth.py @@ -88,7 +88,7 @@ def get(self): req = PreparedRequest() req.prepare_url(url, parameters) - resp = redirect(req.url, code=302) + resp = redirect(str(req.url), code=302) if "refresh_token" in token: resp.set_cookie( "REFRESH_TOKEN", @@ -106,7 +106,7 @@ class RefreshApi(Resource): def post(self): oauth_client = current_app.extensions["authlib.integrations.flask_client"] oauth_client_connext: FlaskOAuth2App = oauth_client.connext - token_string = request.headers.get("Authorization").split(" ")[-1] + token_string = str(request.headers.get("Authorization")).split(" ")[-1] oauth_client_connext.token = token_string oauth_client_connext.load_server_metadata() url = oauth_client_connext.server_metadata["token_endpoint"] @@ -155,7 +155,7 @@ def get(self): logger.debug("No permissions defined, so nobody is permitted to do any api calls.") return [] user_info = get_oauth_token_info(current_token) - permissions_of_user = get_permissions_user(permissions_rules, user_info) + permissions_of_user = get_permissions_user(permissions_rules, user_info) # check check # convert to dictionaries so it can be converted to json permissions_as_dics = [] diff --git a/src/cnaas_nms/api/device.py b/src/cnaas_nms/api/device.py index c35ff137..776641df 100644 --- a/src/cnaas_nms/api/device.py +++ b/src/cnaas_nms/api/device.py @@ -1,6 +1,6 @@ import datetime import json -from typing import List, Optional +from typing import Any, List, Optional from flask import make_response, request from flask_restx import Namespace, Resource, fields, marshal @@ -242,7 +242,7 @@ def get(self, device_id): """Get a device from ID""" result = empty_result() result["data"] = {"devices": []} - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance = session.query(Device).filter(Device.id == device_id).one_or_none() if instance: result["data"]["devices"] = device_data_postprocess([instance]) @@ -273,7 +273,8 @@ def delete(self, device_id): return res elif not isinstance(json_data["factory_default"], bool): return empty_result(status="error", data="Argument factory_default must be boolean"), 400 - with sqla_session() as session: + + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() if not dev: return empty_result("error", "Device not found"), 404 @@ -305,7 +306,7 @@ def delete(self, device_id): def put(self, device_id): """Modify device from ID""" json_data = request.get_json() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() dev_prev_state: DeviceState = dev.state @@ -350,7 +351,7 @@ def get(self, hostname): """Get a device from hostname""" result = empty_result() result["data"] = {"devices": []} - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance = session.query(Device).filter(Device.hostname == hostname).one_or_none() if instance: result["data"]["devices"] = device_data_postprocess([instance]) @@ -371,7 +372,7 @@ def post(self): data, errors = Device.validate(**json_data) if errors != []: return empty_result(status="error", data=errors), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance: Device = session.query(Device).filter(Device.hostname == data["hostname"]).one_or_none() if instance: errors.append("Device already exists") @@ -403,7 +404,7 @@ def get(self): logger.info("started get devices") device_list: List[Device] = [] total_count = 0 - with sqla_session() as session: + with sqla_session() as session: # type: ignore query = session.query(Device, func.count(Device.id).over().label("total")) try: query = build_filter(Device, query) @@ -416,7 +417,7 @@ def get(self): resp = make_response(json.dumps(empty_result(status="success", data=data)), 200) resp.headers["Content-Type"] = "application/json" - resp.headers = {**resp.headers, **pagination_headers(total_count)} + resp.headers = {**resp.headers, **pagination_headers(total_count)} # type: ignore return resp @@ -433,7 +434,7 @@ def post(self, device_id: int): # If device init is already in progress, reschedule a new step2 (connectivity check) # instead of trying to restart initialization - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() if ( dev @@ -475,14 +476,14 @@ def post(self, device_id: int): else: return empty_result(status="error", data="Unsupported 'device_type' provided"), 400 - res = empty_result(data=f"Scheduled job to initialize device_id { device_id }") + res = empty_result(data=f"Scheduled job to initialize device_id { str(device_id) }") res["job_id"] = job_id return res @classmethod def arg_check(cls, device_id: int, json_data: dict) -> dict: - parsed_args = {"device_id": device_id} + parsed_args: dict[str, Any] = {"device_id": device_id} if not isinstance(device_id, int): raise ValueError("'device_id' must be an integer") @@ -540,22 +541,23 @@ class DeviceInitCheckApi(Resource): def post(self, device_id: int): """Perform init check on a device""" json_data = request.get_json() - ret = {} + ret: dict[str, Any] = {} linknets_all = [] + mlag_peer_dev: Optional[Device] try: parsed_args = DeviceInitApi.arg_check(device_id, json_data) target_devtype = DeviceType[parsed_args["device_type"]] target_hostname = parsed_args["new_hostname"] mlag_peer_target_hostname: Optional[str] = None mlag_peer_id: Optional[int] = None - mlag_peer_dev: Optional[Device] = None + mlag_peer_dev = None if "mlag_peer_id" in parsed_args and "mlag_peer_new_hostname" in parsed_args: mlag_peer_target_hostname = parsed_args["mlag_peer_new_hostname"] mlag_peer_id = parsed_args["mlag_peer_id"] except ValueError as e: return empty_result(status="error", data="Error parsing arguments: {}".format(e)), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore try: dev: Device = cnaas_nms.devicehandler.init_device.pre_init_checks(session, device_id) linknets_all = dev.get_linknets_as_dict(session) @@ -566,7 +568,7 @@ def post(self, device_id: int): if mlag_peer_id: try: - mlag_peer_dev: Device = cnaas_nms.devicehandler.init_device.pre_init_checks(session, mlag_peer_id) + mlag_peer_dev = cnaas_nms.devicehandler.init_device.pre_init_checks(session, mlag_peer_id) linknets_all += mlag_peer_dev.get_linknets_as_dict(session) except ValueError as e: return empty_result(status="error", data="ValueError in pre_init_checks: {}".format(e)), 400 @@ -743,7 +745,7 @@ def post(self): resp = make_response(json.dumps(res), 200) if total_count: - resp.headers["X-Total-Count"] = total_count + resp.headers["X-Total-Count"] = str(total_count) resp.headers["Content-Type"] = "application/json" return resp @@ -787,7 +789,7 @@ def post(self, hostname: str): resp = make_response(json.dumps(res), 200) if total_count: - resp.headers["X-Total-Count"] = total_count + resp.headers["X-Total-Count"] = str(total_count) resp.headers["Content-Type"] = "application/json" return resp @@ -806,7 +808,7 @@ def post(self): hostname = str(json_data["hostname"]) if not Device.valid_hostname(hostname): return empty_result(status="error", data=f"Hostname '{hostname}' is not a valid hostname"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev or (dev.state != DeviceState.MANAGED and dev.state != DeviceState.UNMANAGED): return ( @@ -828,7 +830,7 @@ def post(self): resp = make_response(json.dumps(res), 200) if total_count: - resp.headers["X-Total-Count"] = total_count + resp.headers["X-Total-Count"] = str(total_count) resp.headers["Content-Type"] = "application/json" return resp @@ -847,7 +849,7 @@ def post(self): hostname = str(json_data["hostname"]) if not Device.valid_hostname(hostname): return empty_result(status="error", data=f"Hostname '{hostname}' is not a valid hostname"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev or (dev.state != DeviceState.MANAGED and dev.state != DeviceState.UNMANAGED): return ( @@ -873,8 +875,8 @@ def post(self): empty_result(status="error", data=f"Hostname '{mlag_peer_hostname}' is not a valid hostname"), 400, ) - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.hostname == mlag_peer_hostname).one_or_none() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.hostname == mlag_peer_hostname).one_or_none() if not dev or (dev.state != DeviceState.MANAGED and dev.state != DeviceState.UNMANAGED): return ( empty_result( @@ -907,7 +909,7 @@ def post(self): resp = make_response(json.dumps(res), 200) if total_count: - resp.headers["X-Total-Count"] = total_count + resp.headers["X-Total-Count"] = str(total_count) resp.headers["Content-Type"] = "application/json" return resp @@ -956,7 +958,7 @@ def get(self, hostname: str): if not Device.valid_hostname(hostname): return empty_result(status="error", data="Invalid hostname specified"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: return empty_result("error", "Device not found"), 404 @@ -987,7 +989,7 @@ def get(self, hostname: str): if not Device.valid_hostname(hostname): return empty_result(status="error", data="Invalid hostname specified"), 400 - kwargs = {} + kwargs: dict[str, Any] = {} if "job_id" in args: try: kwargs["job_id"] = int(args["job_id"]) @@ -1004,7 +1006,7 @@ def get(self, hostname: str): except Exception: return empty_result("error", "before must be a valid ISO format date time string"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore try: result["data"] = Job.get_previous_config(session, hostname, **kwargs) except JobNotFoundError as e: @@ -1021,7 +1023,7 @@ def get(self, hostname: str): def post(self, hostname: str): """Restore configuration to previous version""" json_data = request.get_json() - apply_kwargs = {"hostname": hostname} + apply_kwargs: dict[str, Any] = {"hostname": hostname} config = None if not Device.valid_hostname(hostname): return empty_result(status="error", data="Invalid hostname specified"), 400 @@ -1034,7 +1036,7 @@ def post(self, hostname: str): else: return empty_result("error", "job_id must be specified"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore try: prev_config_result = Job.get_previous_config(session, hostname, job_id=job_id) failed = prev_config_result["failed"] @@ -1080,7 +1082,7 @@ class DeviceApplyConfigApi(Resource): def post(self, hostname: str): """Apply exact specified configuration to device without using templates""" json_data = request.get_json() - apply_kwargs = {"hostname": hostname} + apply_kwargs: dict[str, Any] = {"hostname": hostname} allow_live_run = api_settings.ALLOW_APPLY_CONFIG_LIVERUN if not Device.valid_hostname(hostname): return empty_result(status="error", data="Invalid hostname specified"), 400 @@ -1165,7 +1167,7 @@ def post(self): resp = make_response(json.dumps(res), 200) if total_count: - resp.headers["X-Total-Count"] = total_count + resp.headers["X-Total-Count"] = str(total_count) resp.headers["Content-Type"] = "application/json" return resp else: @@ -1177,7 +1179,7 @@ class DeviceStackmembersApi(Resource): def get(self, hostname): """Get stackmembers for device""" result = empty_result(data={"stackmembers": []}) - with sqla_session() as session: + with sqla_session() as session: # type: ignore device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not device: return empty_result("error", "Device not found"), 404 @@ -1196,7 +1198,7 @@ def put(self, hostname): errors = DeviceStackmembersApi.format_errors(e.errors()) return empty_result("error", errors), 400 result = empty_result(data={"stackmembers": []}) - with sqla_session() as session: + with sqla_session() as session: # type: ignore device_instance = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not device_instance: return empty_result("error", "Device not found"), 404 @@ -1232,13 +1234,14 @@ def get(self): args = request.args result = empty_result() result["data"] = {"hostnames": {}} + sync_history: SyncHistory if "hostname" in args: if not Device.valid_hostname(args["hostname"]): return empty_result(status="error", data="Invalid hostname specified"), 400 - sync_history: SyncHistory = get_sync_events([args["hostname"]]) + sync_history = get_sync_events([args["hostname"]]) else: - sync_history: SyncHistory = get_sync_events() + sync_history = get_sync_events() result["data"]["hostnames"] = sync_history.asdict() return result @@ -1250,7 +1253,7 @@ def post(self): validated_json_data = NewSyncEventModel(**request.get_json()).model_dump() except ValidationError as e: return empty_result("error", parse_pydantic_error(e, NewSyncEventModel, request.get_json())), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore device_instance = ( session.query(Device).filter(Device.hostname == validated_json_data["hostname"]).one_or_none() ) diff --git a/src/cnaas_nms/api/firmware.py b/src/cnaas_nms/api/firmware.py index 78ca8e72..972f05f6 100644 --- a/src/cnaas_nms/api/firmware.py +++ b/src/cnaas_nms/api/firmware.py @@ -1,6 +1,6 @@ import json from datetime import datetime -from typing import Optional +from typing import Any, Optional import requests from flask import make_response, request @@ -60,18 +60,18 @@ def get_firmware(**kwargs: dict) -> str: return "Could not download firmware: " + str(e) if json_data["status"] == "error": return json_data["message"] - return "File downloaded from: " + kwargs["url"] + return "File downloaded from: " + str(kwargs["url"]) @job_wrapper def get_firmware_chksum(**kwargs: dict) -> str: try: - url = api_settings.HTTPD_URL + "/" + kwargs["filename"] + url = api_settings.HTTPD_URL + "/" + str(kwargs["filename"]) res = requests.get(url, verify=api_settings.VERIFY_TLS) json_data = json.loads(res.content) except Exception as e: logger.exception(f"Exceptionb while getting checksum: {e}") - return "Failed to get checksum for " + kwargs["filename"] + return "Failed to get checksum for " + str(kwargs["filename"]) if json_data["status"] == "error": return json_data["message"] return json_data["data"]["file"]["sha1"] @@ -80,21 +80,21 @@ def get_firmware_chksum(**kwargs: dict) -> str: @job_wrapper def remove_file(**kwargs: dict) -> str: try: - url = api_settings.HTTPD_URL + "/" + kwargs["filename"] + url = api_settings.HTTPD_URL + "/" + str(kwargs["filename"]) res = requests.delete(url, verify=api_settings.VERIFY_TLS) json_data = json.loads(res.content) except Exception as e: logger.exception(f"Exception when removing firmware: {e}") return "Failed to remove file" if json_data["status"] == "error": - return "Failed to remove file " + kwargs["filename"] - return "File " + kwargs["filename"] + " removed" + return "Failed to remove file " + str(kwargs["filename"]) + return "File " + str(kwargs["filename"]) + " removed" class FirmwareApi(Resource): @login_required @api.expect(firmware_model) - def post(self) -> tuple: + def post(self) -> dict[str, Any]: """Download new firmware""" json_data = request.get_json() @@ -113,7 +113,7 @@ def post(self) -> tuple: kwargs["sha1"] = json_data["sha1"] kwargs["verify_tls"] = json_data["verify_tls"] - scheduler = Scheduler() + scheduler: Scheduler = Scheduler() job_id = scheduler.add_onetime_job( "cnaas_nms.api.firmware:get_firmware", when=1, scheduled_by=get_identity(), kwargs=kwargs ) @@ -123,7 +123,7 @@ def post(self) -> tuple: return res @login_required - def get(self) -> tuple: + def get(self) -> dict[str, Any] | tuple[dict[str, Any], int]: """Get firmwares""" try: res = requests.get(api_settings.HTTPD_URL, verify=api_settings.VERIFY_TLS) @@ -138,7 +138,7 @@ class FirmwareImageApi(Resource): @login_required def get(self, filename: str) -> dict: """Get information about a single firmware""" - scheduler = Scheduler() + scheduler: Scheduler = Scheduler() job_id = scheduler.add_onetime_job( "cnaas_nms.api.firmware:get_firmware_chksum", when=1, @@ -153,7 +153,7 @@ def get(self, filename: str) -> dict: @login_required def delete(self, filename: str) -> dict: """Remove firmware""" - scheduler = Scheduler() + scheduler: Scheduler = Scheduler() job_id = scheduler.add_onetime_job( "cnaas_nms.api.firmware:remove_file", when=1, scheduled_by=get_identity(), kwargs={"filename": filename} ) @@ -170,7 +170,7 @@ def post(self): """Upgrade firmware on device""" json_data = request.get_json() - kwargs = dict() + kwargs: dict[str, Any] = dict() seconds = 1 date_format = "%Y-%m-%d %H:%M:%S" url = api_settings.FIRMWARE_URL @@ -272,7 +272,7 @@ def post(self): logger.exception(f"Exception when scheduling job: {e}") return empty_result(status="error", data=f"Invalid date format, should be: {date_format}") - scheduler = Scheduler() + scheduler: Scheduler = Scheduler() job_id = scheduler.add_onetime_job( "cnaas_nms.devicehandler.firmware:device_upgrade", when=seconds, @@ -284,7 +284,7 @@ def post(self): resp = make_response(json.dumps(res), 200) if total_count: - resp.headers["X-Total-Count"] = total_count + resp.headers["X-Total-Count"] = str(total_count) resp.headers["Content-Type"] = "application/json" return resp diff --git a/src/cnaas_nms/api/generic.py b/src/cnaas_nms/api/generic.py index fb1c03ae..9b307555 100644 --- a/src/cnaas_nms/api/generic.py +++ b/src/cnaas_nms/api/generic.py @@ -1,10 +1,11 @@ import math import re import urllib -from typing import List +from typing import Any, Dict, List import sqlalchemy from flask import request +from pydantic import ValidationError from cnaas_nms.db.settings import get_pydantic_error_value, get_pydantic_field_descr @@ -151,7 +152,7 @@ def build_filter(f_class, query: sqlalchemy.orm.query.Query): query = query.filter(f_class_op(value)) - if f_class_order_by_field: + if f_class_order_by_field and order: query = query.order_by(order(f_class_order_by_field)) else: if "id" in f_class.__table__._columns.keys(): @@ -163,14 +164,16 @@ def build_filter(f_class, query: sqlalchemy.orm.query.Query): return query -def empty_result(status="success", data=None): +def empty_result(status="success", data=None) -> Dict[str, Any]: if status == "success": return {"status": status, "data": data} elif status == "error": return {"status": status, "message": data if data else "Unknown error"} + else: + return {} -def parse_pydantic_error(e: Exception, schema, data: dict) -> List[str]: +def parse_pydantic_error(e: ValidationError, schema, data: dict) -> List[str]: errors = [] for num, error in enumerate(e.errors()): loc = error["loc"] diff --git a/src/cnaas_nms/api/groups.py b/src/cnaas_nms/api/groups.py index cd7a84ce..821b7e85 100644 --- a/src/cnaas_nms/api/groups.py +++ b/src/cnaas_nms/api/groups.py @@ -17,8 +17,8 @@ def groups_populate(group_name: Optional[str] = None) -> dict: if group_name: tmpgroups: dict = {group_name: []} else: - tmpgroups: dict = {key: [] for key in get_groups()} - with sqla_session() as session: + tmpgroups = {key: [] for key in get_groups()} + with sqla_session() as session: # type: ignore devices: List[Device] = session.query(Device).all() for dev in devices: groups = get_groups(dev.hostname) @@ -52,7 +52,7 @@ def groups_osversion_populate(group_name: str): else: raise ValueError("Could not find group {}".format(group_name)) - with sqla_session() as session: + with sqla_session() as session: # type: ignore devices: List[Device] = ( session.query(Device).filter(Device.state == DeviceState.MANAGED).order_by(Device.hostname.asc()).all() ) diff --git a/src/cnaas_nms/api/interface.py b/src/cnaas_nms/api/interface.py index bf881338..d6687f36 100644 --- a/src/cnaas_nms/api/interface.py +++ b/src/cnaas_nms/api/interface.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, List from flask import request from flask_restx import Namespace, Resource, fields @@ -69,7 +69,7 @@ def get(self, hostname): """List all interfaces""" result = empty_result() result["data"] = {"interfaces": []} - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: return empty_result("error", "Device not found"), 404 @@ -95,8 +95,7 @@ def put(self, hostname): data = {} errors = [] device_settings = None - - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: return empty_result("error", "Device not found"), 404 @@ -117,8 +116,8 @@ def put(self, hostname): errors.append(f"Interface {if_name} not found") continue if intf.data and isinstance(intf.data, dict): - intfdata_original = dict(intf.data) - intfdata = dict(intf.data) + intfdata_original: dict[str, Any] = dict(intf.data) + intfdata: dict[str, Any] = dict(intf.data) else: intfdata_original = {} intfdata = {} @@ -130,8 +129,8 @@ def put(self, hostname): errors.append("configtype is not a string") else: if InterfaceConfigType.has_name(configtype): - if intf.configtype != InterfaceConfigType[configtype]: - intf.configtype = InterfaceConfigType[configtype] + if intf.configtype != str(InterfaceConfigType[configtype]): + intf.configtype = str(InterfaceConfigType[configtype]) updated = True data[if_name] = {"configtype": configtype} else: @@ -272,7 +271,7 @@ def put(self, hostname): "cli_append_str must be a string, got: {}".format(if_dict["data"]["cli_append_str"]) ) elif "data" in if_dict and not if_dict["data"]: - intfdata = None + intfdata = {} if intfdata != intfdata_original: intf.data = intfdata diff --git a/src/cnaas_nms/api/jobs.py b/src/cnaas_nms/api/jobs.py index 6f0e854b..fcbbc500 100644 --- a/src/cnaas_nms/api/jobs.py +++ b/src/cnaas_nms/api/jobs.py @@ -1,5 +1,6 @@ import json import time +from typing import Any from flask import make_response, request from flask_restx import Namespace, Resource, fields @@ -61,10 +62,10 @@ class JobsApi(Resource): @login_required def get(self): """Get one or more jobs""" - data = {"jobs": []} + data: dict[str, Any] = {"jobs": []} total_count = 0 args = request.args - with sqla_session() as session: + with sqla_session() as session: # type: ignore query = session.query(Job, func.count(Job.id).over().label("total")) try: query = build_filter(Job, query) @@ -78,7 +79,7 @@ def get(self): resp = make_response(json.dumps(empty_result(status="success", data=data)), 200) resp.headers["Content-Type"] = "application/json" - resp.headers = {**resp.headers, **pagination_headers(total_count)} + resp.headers = {**resp.headers, **pagination_headers(total_count)} # type: ignore return resp @@ -87,7 +88,7 @@ class JobByIdApi(Resource): def get(self, job_id): """Get job information by ID""" args = request.args - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() if job: job_dict = job.as_dict() @@ -102,7 +103,7 @@ def put(self, job_id): if "action" not in json_data: return empty_result(status="error", data="Action must be specified"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() if not job: return empty_result(status="error", data="No job with id {} found".format(job_id)), 400 @@ -132,11 +133,11 @@ def put(self, job_id): scheduler.remove_scheduled_job(job_id=job_id, abort_message=abort_reason) time.sleep(2) elif job_status == JobStatus.RUNNING: - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() job.status = JobStatus.ABORTING - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() return empty_result(data={"jobs": [job.as_dict()]}) else: @@ -148,7 +149,7 @@ class JobLockApi(Resource): def get(self): """Get job locks""" locks = [] - with sqla_session() as session: + with sqla_session() as session: # type: ignore for lock in session.query(Joblock).all(): locks.append(lock.as_dict()) return empty_result("success", data={"locks": locks}) @@ -161,7 +162,7 @@ def delete(self): if "name" not in json_data or not json_data["name"]: return empty_result("error", "No lock name specified"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore lock = session.query(Joblock).filter(Joblock.name == json_data["name"]).one_or_none() if lock: session.delete(lock) diff --git a/src/cnaas_nms/api/json.py b/src/cnaas_nms/api/json.py index 0dcf9968..c1f6895a 100644 --- a/src/cnaas_nms/api/json.py +++ b/src/cnaas_nms/api/json.py @@ -9,4 +9,4 @@ def default(self, o): if isinstance(o, _IPAddressBase): return str(o) else: - return super().default(self, o) + return super().default(o) diff --git a/src/cnaas_nms/api/linknet.py b/src/cnaas_nms/api/linknet.py index 95bbe189..1430e714 100644 --- a/src/cnaas_nms/api/linknet.py +++ b/src/cnaas_nms/api/linknet.py @@ -1,9 +1,9 @@ from ipaddress import IPv4Address, IPv4Network -from typing import Optional +from typing import Any, Optional from flask import request from flask_restx import Namespace, Resource, fields -from pydantic import BaseModel, FieldValidationInfo, ValidationError, field_validator +from pydantic import BaseModel, ValidationError, ValidationInfo, field_validator from cnaas_nms.api.generic import empty_result, parse_pydantic_error, update_sqla_object from cnaas_nms.db.device import Device, DeviceType @@ -52,7 +52,7 @@ class f_linknet(BaseModel): @field_validator("device_a_ip", "device_b_ip") @classmethod - def device_ip_validator(cls, v, info: FieldValidationInfo): + def device_ip_validator(cls, v, info: ValidationInfo): if not v: return v if not info.data["ipv4_network"]: @@ -94,7 +94,7 @@ class LinknetsApi(Resource): def validate_hostname(hostname): if not Device.valid_hostname(hostname): raise ValueError("Invalid hostname: {}".format(hostname)) - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: raise ValueError("Hostname {} not found in database") @@ -104,8 +104,8 @@ def validate_hostname(hostname): @login_required def get(self): """Get all linksnets""" - result = {"linknets": []} - with sqla_session() as session: + result: dict[str, Any] = {"linknets": []} + with sqla_session() as session: # type: ignore query = session.query(Linknet) for instance in query: result["linknets"].append(instance.as_dict()) @@ -143,7 +143,7 @@ def post(self): if errors: return empty_result(status="error", data=errors), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev_a: Device = session.query(Device).filter(Device.hostname == json_data["device_a"]).one_or_none() if not dev_a: return empty_result(status="error", data="Hostname '{}' not found".format(json_data["device_a"])), 500 @@ -195,7 +195,7 @@ def delete(self): if errors: return empty_result(status="error", data=errors), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore cur_linknet: Linknet = session.query(Linknet).filter(Linknet.id == json_data["id"]).one_or_none() if not cur_linknet: return empty_result(status="error", data="No such linknet found in database"), 404 @@ -214,7 +214,7 @@ def get(self, linknet_id): """Get a single specified linknet""" result = empty_result() result["data"] = {"linknets": []} - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance = session.query(Linknet).filter(Linknet.id == linknet_id).one_or_none() if instance: result["data"]["linknets"].append(instance.as_dict()) @@ -225,7 +225,7 @@ def get(self, linknet_id): @login_required def delete(self, linknet_id): """Remove a linknet""" - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance: Linknet = session.query(Linknet).filter(Linknet.id == linknet_id).one_or_none() if instance: instance.device_a.synchronized = False @@ -254,7 +254,7 @@ def put(self, linknet_id): if errors: return empty_result(status="error", data=errors), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance: Linknet = session.query(Linknet).filter(Linknet.id == linknet_id).one_or_none() if instance: try: diff --git a/src/cnaas_nms/api/mgmtdomain.py b/src/cnaas_nms/api/mgmtdomain.py index 780328fd..adb3c08e 100644 --- a/src/cnaas_nms/api/mgmtdomain.py +++ b/src/cnaas_nms/api/mgmtdomain.py @@ -88,7 +88,7 @@ def get(self, mgmtdomain_id): """Get management domain by ID""" result = empty_result() result["data"] = {"mgmtdomains": []} - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance = session.query(Mgmtdomain).filter(Mgmtdomain.id == mgmtdomain_id).one_or_none() if instance: result["data"]["mgmtdomains"].append(instance.as_dict()) @@ -99,7 +99,7 @@ def get(self, mgmtdomain_id): @login_required def delete(self, mgmtdomain_id): """Remove management domain""" - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance: Mgmtdomain = session.query(Mgmtdomain).filter(Mgmtdomain.id == mgmtdomain_id).one_or_none() if instance: instance.device_a.synchronized = False @@ -126,7 +126,7 @@ def put(self, mgmtdomain_id): if errors: return empty_result("error", errors), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore instance: Mgmtdomain = session.query(Mgmtdomain).filter(Mgmtdomain.id == mgmtdomain_id).one_or_none() if instance: changed: bool = update_sqla_object(instance, json_data) @@ -148,7 +148,7 @@ def get(self): """Get all management domains""" result = empty_result() result["data"] = {"mgmtdomains": []} - with sqla_session() as session: + with sqla_session() as session: # type: ignore query = session.query(Mgmtdomain) try: query = build_filter(Mgmtdomain, query).limit(limit_results()) @@ -165,7 +165,7 @@ def post(self): json_data = request.get_json() data = {} errors = [] - with sqla_session() as session: + with sqla_session() as session: # type: ignore if "device_a" in json_data: hostname_a = str(json_data["device_a"]) if not Device.valid_hostname(hostname_a): @@ -212,7 +212,7 @@ def post(self): session.flush() except IntegrityError as e: session.rollback() - if "duplicate" in str(e): + if "duplicate" in str(e) and e.orig: return empty_result("error", "Duplicate value: {}".format(e.orig.args[0])), 400 else: return empty_result("error", "Integrity error: {}".format(e)), 400 diff --git a/src/cnaas_nms/api/models/stackmembers_model.py b/src/cnaas_nms/api/models/stackmembers_model.py index 37049623..46c817a7 100644 --- a/src/cnaas_nms/api/models/stackmembers_model.py +++ b/src/cnaas_nms/api/models/stackmembers_model.py @@ -1,12 +1,12 @@ from typing import List, Optional -from pydantic import BaseModel, conint, field_validator +from pydantic import BaseModel, Field, field_validator class StackmemberModel(BaseModel): - member_no: Optional[conint(gt=-1)] = None + member_no: Optional[int] = Field(None, gt=-1) hardware_id: str - priority: Optional[conint(gt=-1)] = None + priority: Optional[int] = Field(None, gt=-1) @field_validator("hardware_id") @classmethod diff --git a/src/cnaas_nms/api/settings.py b/src/cnaas_nms/api/settings.py index 88a5ba00..3e3c7d5f 100644 --- a/src/cnaas_nms/api/settings.py +++ b/src/cnaas_nms/api/settings.py @@ -34,7 +34,7 @@ def get(self): hostname = args["hostname"] else: return empty_result("error", "Invalid hostname specified"), 400 - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if dev: device_type = dev.device_type diff --git a/src/cnaas_nms/api/tests/test_device.py b/src/cnaas_nms/api/tests/test_device.py index 1e307378..c524b0f9 100644 --- a/src/cnaas_nms/api/tests/test_device.py +++ b/src/cnaas_nms/api/tests/test_device.py @@ -22,7 +22,7 @@ def requirements(self, postgresql, settings_directory): pass def cleandb(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore for hardware_id in ["AB1234", "CD5555", "GF43534"]: stack = session.query(Stackmember).filter(Stackmember.hardware_id == hardware_id).one_or_none() if stack: @@ -53,7 +53,7 @@ def tearDown(self): self.cleandb() def add_device(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore device = Device( hostname="testdevice", platform="eos", @@ -110,14 +110,14 @@ def test_modify_device(self): json_data = json.loads(result.data.decode()) updated_device = json_data["data"]["updated_device"] self.assertEqual(modify_data["description"], updated_device["description"]) - with sqla_session() as session: + with sqla_session() as session: # type: ignore q_device = session.query(Device).filter(Device.hostname == self.hostname).one_or_none() self.assertEqual(modify_data["description"], q_device.description) def test_delete_device(self): result = self.client.delete(f"/api/v1.0/device/{self.device_id}") self.assertEqual(result.status_code, 200) - with sqla_session() as session: + with sqla_session() as session: # type: ignore q_device = session.query(Device).filter(Device.hostname == self.hostname).one_or_none() self.assertIsNone(q_device) @@ -146,7 +146,7 @@ def test_get_stackmembers_no_stackmembers(self): self.assertEqual(json_data["data"]["stackmembers"], []) def test_get_stackmembers(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore stackmember = Stackmember(device_id=self.device_id, hardware_id="AB1234", member_no=1, priority=3) session.add(stackmember) result = self.client.get(f"/api/v1.0/device/{self.hostname}/stackmember") @@ -167,7 +167,7 @@ def test_put_stackmembers_valid(self): json_data = json.loads(result.data.decode()) self.assertEqual(result.status_code, 200, msg=json_data) self.assertEqual(len(json_data["data"]["stackmembers"]), 3, msg=json_data) - with sqla_session() as session: + with sqla_session() as session: # type: ignore q_stackmembers = session.query(Stackmember).filter(Stackmember.device_id == self.device_id).all() self.assertEqual(len(q_stackmembers), 3, msg=json_data) @@ -187,7 +187,7 @@ def test_put_stackmembers_invalid_hardware_id(self): self.assertEqual(result.status_code, 400) def test_put_stackmembers_clear(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore stackmember = Stackmember( device_id=self.device_id, hardware_id="AB1234", @@ -200,7 +200,7 @@ def test_put_stackmembers_clear(self): json_data = json.loads(result.data.decode()) self.assertEqual(result.status_code, 200) self.assertEqual(len(json_data["data"]["stackmembers"]), 0) - with sqla_session() as session: + with sqla_session() as session: # type: ignore q_stackmembers = session.query(Stackmember).filter(Stackmember.device_id == self.device_id).all() self.assertEqual(len(q_stackmembers), 0) diff --git a/src/cnaas_nms/app_settings.py b/src/cnaas_nms/app_settings.py index ce64b238..f1242161 100644 --- a/src/cnaas_nms/app_settings.py +++ b/src/cnaas_nms/app_settings.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional +from typing import Any, Optional import yaml from pydantic import field_validator @@ -29,9 +29,9 @@ class AppSettings(BaseSettings): USERNAME_MANAGED: str = "admin" PASSWORD_MANAGED: str = "abc123abc123" TEMPLATES_REMOTE: str = "/opt/git/cnaas-templates-origin.git" - TEMPLATES_LOCAL: str = "/opt/cnaas/templates" + TEMPLATES_LOCAL: str = "/opt/git/cnaas-templates" SETTINGS_REMOTE: str = "/opt/git/cnaas-settings-origin.git" - SETTINGS_LOCAL: str = "/opt/cnaas/settings" + SETTINGS_LOCAL: str = "/opt/git/cnaas-settings" class ApiSettings(BaseSettings): @@ -184,7 +184,7 @@ def construct_auth_settings() -> AuthSettings: auth_settings.AUDIENCE = auth_settings.OIDC_CLIENT_ID if auth_settings.PERMISSIONS_DISABLED: - permissions_rules = { + permissions_rules: dict[str, Any] = { "config": {"default_permissions": "default"}, "roles": { "default": {"permissions": [{"methods": ["*"], "endpoints": ["*"], "pages": ["*"], "rights": ["*"]}]} diff --git a/src/cnaas_nms/db/device.py b/src/cnaas_nms/db/device.py index b9cc7b9e..5bc0557c 100644 --- a/src/cnaas_nms/db/device.py +++ b/src/cnaas_nms/db/device.py @@ -5,7 +5,7 @@ import ipaddress import json import re -from typing import List, Optional, Set +from typing import List, Optional, Set, Tuple from sqlalchemy import Boolean, DateTime, Enum, ForeignKey, Integer, String, Unicode, UniqueConstraint, event from sqlalchemy.orm import Mapped, mapped_column, relationship @@ -115,12 +115,14 @@ def as_dict(self) -> dict: d[col.name] = value return d - def get_neighbors(self, session, linknets: Optional[List[dict]] = None) -> Set[Device]: + def get_neighbors( + self, session, linknets: Optional[List[dict]] | Optional[List[cnaas_nms.db.linknet.Linknet]] = None + ) -> Set[Device]: """Look up neighbors from cnaas_nms.db.linknet.Linknets and return them as a list of Device objects.""" if not linknets: linknets = self.get_linknets(session) ret: Set = set() - for linknet in linknets: + for linknet in linknets: # type: ignore if isinstance(linknet, cnaas_nms.db.linknet.Linknet): device_a_id = linknet.device_a_id device_b_id = linknet.device_b_id @@ -188,8 +190,8 @@ def get_links_to(self, session, peer_device: Device) -> List[cnaas_nms.db.linkne ) def get_neighbor_ifnames( - self, session, peer_device: Device, linknets_arg: [Optional[List[dict]]] = None - ) -> List[(str, str)]: + self, session, peer_device: Device, linknets_arg: Optional[List[dict]] = None + ) -> List[Tuple[str, str]]: """Get the interface names connecting self device with peer device. Returns: @@ -231,6 +233,7 @@ def get_neighbor_local_ipif(self, session, peer_device: Device) -> Optional[str] return "{}/{}".format(linknet.device_a_ip, ipaddress.IPv4Network(linknet.ipv4_network).prefixlen) elif linknet.device_b_id == self.id: return "{}/{}".format(linknet.device_b_ip, ipaddress.IPv4Network(linknet.ipv4_network).prefixlen) + return None def get_neighbor_ip(self, session, peer_device: Device): """Get the remote peer IP address for the linknet going towards device.""" diff --git a/src/cnaas_nms/db/git.py b/src/cnaas_nms/db/git.py index e4ce24b3..ba4b8971 100644 --- a/src/cnaas_nms/db/git.py +++ b/src/cnaas_nms/db/git.py @@ -7,7 +7,6 @@ from urllib.parse import urldefrag import yaml -from git.exc import GitCommandError, NoSuchPathError from cnaas_nms.app_settings import app_settings from cnaas_nms.db.device import Device, DeviceType @@ -29,6 +28,7 @@ from cnaas_nms.tools.event import add_event from cnaas_nms.tools.log import get_logger from git import InvalidGitRepositoryError, Repo +from git.exc import GitCommandError, NoSuchPathError class RepoType(enum.Enum): diff --git a/src/cnaas_nms/db/helper.py b/src/cnaas_nms/db/helper.py index ce52d75e..803dd56a 100644 --- a/src/cnaas_nms/db/helper.py +++ b/src/cnaas_nms/db/helper.py @@ -18,57 +18,34 @@ def canonical_mac(mac): return str(na_mac) -def find_mgmtdomain(session, hostnames: List[str]) -> Optional[Mgmtdomain]: - """Find the corresponding management domain for a pair of - distribution switches. +def find_mgmtdomain_one_device(session, device0: Device) -> Optional[Mgmtdomain]: + if device0.device_type == DeviceType.DIST: + mgmtdomain = ( + session.query(Mgmtdomain) + .filter((Mgmtdomain.device_a == device0) | (Mgmtdomain.device_b == device0)) + .limit(1) + .one_or_none() + ) + if not mgmtdomain: + raise Exception("No mgmtdomain found for uplink device: {}".format(device0.hostname)) + elif device0.device_type == DeviceType.ACCESS: + if device0.management_ip: + mgmtdomain = find_mgmtdomain_by_ip(session, IPv4Address(device0.management_ip)) + else: + raise Exception("No mgmtdomain found for uplink device: {}".format(device0.hostname)) + else: + raise Exception("Unexpected uplink device type: {}".format(device0.device_type)) + return mgmtdomain - Args: - hostnames: A list of one or two hostnames of uplink devices - Raises: - ValueError: On invalid hostnames etc - Exception: General exceptions - """ - mgmtdomain: Optional[Mgmtdomain] = None - if not isinstance(hostnames, list) or not 1 <= len(hostnames) <= 2: +def find_mgmtdomain_two_devices(session, device0: Device, device1: Device) -> Optional[Mgmtdomain]: + if device0.device_type != device1.device_type: raise ValueError( - "One or two uplink devices are required to find a compatible mgmtdomain, got: {}".format(hostnames) + "Both uplink devices must be of same device type: {}, {}".format(device0.hostname, device1.hostname) ) - for hostname in hostnames: - if not Device.valid_hostname(hostname): - raise ValueError(f"Argument {hostname} is not a valid hostname") - try: - device0: Device = session.query(Device).filter(Device.hostname == hostnames[0]).one() - except NoResultFound: - raise ValueError(f"hostname {hostnames[0]} not found in device database") - - if len(hostnames) == 2: - try: - device1: Optional[Device] = session.query(Device).filter(Device.hostname == hostnames[1]).one() - except NoResultFound: - raise ValueError(f"hostname {hostnames[1]} not found in device database") - else: - device1: Optional[Device] = None - - if len(hostnames) == 1: - if device0.device_type == DeviceType.DIST: - mgmtdomain: Optional[Mgmtdomain] = ( - session.query(Mgmtdomain) - .filter((Mgmtdomain.device_a == device0) | (Mgmtdomain.device_b == device0)) - .limit(1) - .one_or_none() - ) - if not mgmtdomain: - raise Exception("No mgmtdomain found for uplink device: {}".format(device0.hostname)) - elif device0.device_type == DeviceType.ACCESS: - mgmtdomain: Optional[Mgmtdomain] = find_mgmtdomain_by_ip(session, device0.management_ip) - elif device0.device_type == DeviceType.DIST or device1.device_type == DeviceType.DIST: - if device0.device_type != DeviceType.DIST or device1.device_type != DeviceType.DIST: - raise ValueError( - "Both uplink devices must be of same device type: {}, {}".format(device0.hostname, device1.hostname) - ) + elif device0.device_type == DeviceType.DIST: try: - mgmtdomain: Mgmtdomain = ( + mgmtdomain = ( session.query(Mgmtdomain) .filter( ((Mgmtdomain.device_a == device0) & (Mgmtdomain.device_b == device1)) @@ -79,7 +56,7 @@ def find_mgmtdomain(session, hostnames: List[str]) -> Optional[Mgmtdomain]: # If no mgmtdomain has been found, check if there is exactly one mgmtdomain # defined that has two core devices as members and use that instead if not mgmtdomain: - mgmtdomain: Mgmtdomain = ( + mgmtdomain = ( session.query(Mgmtdomain) .filter( (Mgmtdomain.device_a.has(Device.device_type == DeviceType.CORE)) @@ -89,17 +66,16 @@ def find_mgmtdomain(session, hostnames: List[str]) -> Optional[Mgmtdomain]: ) except MultipleResultsFound: raise Exception("Found multiple possible mgmtdomains, please remove any redundant mgmtdomains") - elif device0.device_type == DeviceType.ACCESS or device1.device_type == DeviceType.ACCESS: - if device0.device_type != DeviceType.ACCESS or device1.device_type != DeviceType.ACCESS: - raise ValueError( - "Both uplink devices must be of same device type: {}, {}".format(device0.hostname, device1.hostname) - ) - mgmtdomain0: Optional[Mgmtdomain] = find_mgmtdomain_by_ip(session, device0.management_ip) - mgmtdomain1: Optional[Mgmtdomain] = find_mgmtdomain_by_ip(session, device1.management_ip) + elif device0.device_type == DeviceType.ACCESS: + mgmtdomain0: Optional[Mgmtdomain] = find_mgmtdomain_by_ip(session, IPv4Address(device0.management_ip)) + mgmtdomain1: Optional[Mgmtdomain] = find_mgmtdomain_by_ip(session, IPv4Address(device1.management_ip)) if not mgmtdomain0 or not mgmtdomain1: raise Exception( "Uplink access devices are missing mgmtdomains: {}: {}, {}: {}".format( - device0.hostname, mgmtdomain0.ipv4_gw, device1.hostname, mgmtdomain1.ipv4_gw + device0.hostname, + mgmtdomain0.ipv4_gw if mgmtdomain0 else "", + device1.hostname, + mgmtdomain1.ipv4_gw if mgmtdomain1 else "", ) ) elif mgmtdomain0.id != mgmtdomain1.id: @@ -113,6 +89,42 @@ def find_mgmtdomain(session, hostnames: List[str]) -> Optional[Mgmtdomain]: return mgmtdomain +def find_mgmtdomain(session, hostnames: List[str]) -> Optional[Mgmtdomain]: + """Find the corresponding management domain for a pair of + distribution switches. + + Args: + hostnames: A list of one or two hostnames of uplink devices + + Raises: + ValueError: On invalid hostnames etc + Exception: General exceptions + """ + if not isinstance(hostnames, list) or not 1 <= len(hostnames) <= 2: + raise ValueError( + "One or two uplink devices are required to find a compatible mgmtdomain, got: {}".format(hostnames) + ) + for hostname in hostnames: + if not Device.valid_hostname(hostname): + raise ValueError(f"Argument {hostname} is not a valid hostname") + try: + device0: Device = session.query(Device).filter(Device.hostname == hostnames[0]).one() + except NoResultFound: + raise ValueError(f"hostname {hostnames[0]} not found in device database") + + # handle 1 hostname + if len(hostnames) == 1: + return find_mgmtdomain_one_device(session, device0) + + # handle 2 hostnames + try: + device1: Device = session.query(Device).filter(Device.hostname == hostnames[1]).one() + except NoResultFound: + raise ValueError(f"hostname {hostnames[1]} not found in device database") + + return find_mgmtdomain_two_devices(session, device0, device1) + + def find_mgmtdomain_by_ip(session, ipv4_address: IPv4Address) -> Optional[Mgmtdomain]: mgmtdomains = session.query(Mgmtdomain).all() mgmtdom: Mgmtdomain diff --git a/src/cnaas_nms/db/interface.py b/src/cnaas_nms/db/interface.py index 26dcfef8..ac9d318e 100644 --- a/src/cnaas_nms/db/interface.py +++ b/src/cnaas_nms/db/interface.py @@ -1,9 +1,9 @@ import enum import re -from sqlalchemy import Column, Enum, ForeignKey, Integer, Unicode +from sqlalchemy import Enum, ForeignKey, Integer, Unicode from sqlalchemy.dialects.postgresql.json import JSONB -from sqlalchemy.orm import backref, relationship +from sqlalchemy.orm import backref, mapped_column, relationship import cnaas_nms.db.base import cnaas_nms.db.device @@ -38,13 +38,13 @@ def has_name(cls, value): class Interface(cnaas_nms.db.base.Base): __tablename__ = "interface" __table_args__ = (None,) - device_id = Column(Integer, ForeignKey("device.id"), primary_key=True, index=True) + device_id = mapped_column(Integer, ForeignKey("device.id"), primary_key=True, index=True) device = relationship( "Device", foreign_keys=[device_id], backref=backref("Interfaces", cascade="all, delete-orphan") ) - name = Column(Unicode(255), primary_key=True) - configtype = Column(Enum(InterfaceConfigType), nullable=False) - data = Column(JSONB) + name = mapped_column(Unicode(255), primary_key=True) + configtype = mapped_column(Enum(InterfaceConfigType), nullable=False) + data = mapped_column(JSONB) def as_dict(self) -> dict: """Return JSON serializable dict.""" diff --git a/src/cnaas_nms/db/job.py b/src/cnaas_nms/db/job.py index 35299a8e..fa36049b 100644 --- a/src/cnaas_nms/db/job.py +++ b/src/cnaas_nms/db/job.py @@ -5,9 +5,9 @@ from typing import Dict, List, Optional from nornir.core.task import AggregatedResult -from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, SmallInteger, Unicode +from sqlalchemy import DateTime, Enum, ForeignKey, Integer, SmallInteger, Unicode from sqlalchemy.dialects.postgresql.json import JSONB -from sqlalchemy.orm import relationship +from sqlalchemy.orm import mapped_column, relationship import cnaas_nms.db.base import cnaas_nms.db.device @@ -49,22 +49,22 @@ def has_name(cls, value): class Job(cnaas_nms.db.base.Base): __tablename__ = "job" __table_args__ = (None,) - id = Column(Integer, autoincrement=True, primary_key=True) - status = Column(Enum(JobStatus), index=True, default=JobStatus.SCHEDULED) - scheduled_time = Column(DateTime, default=datetime.datetime.utcnow) - start_time = Column(DateTime) - finish_time = Column(DateTime, index=True) - function_name = Column(Unicode(255)) - scheduled_by = Column(Unicode(255)) - comment = Column(Unicode(255)) - ticket_ref = Column(Unicode(32), index=True) - next_job_id = Column(Integer, ForeignKey("job.id")) + id = mapped_column(Integer, autoincrement=True, primary_key=True) + status = mapped_column(Enum(JobStatus), index=True, default=JobStatus.SCHEDULED) + scheduled_time = mapped_column(DateTime, default=datetime.datetime.utcnow) + start_time = mapped_column(DateTime) + finish_time = mapped_column(DateTime, index=True) + function_name = mapped_column(Unicode(255)) + scheduled_by = mapped_column(Unicode(255)) + comment = mapped_column(Unicode(255)) + ticket_ref = mapped_column(Unicode(32), index=True) + next_job_id = mapped_column(Integer, ForeignKey("job.id")) next_job = relationship("Job", remote_side=[id]) - result = Column(JSONB) - exception = Column(JSONB) - finished_devices = Column(JSONB) - change_score = Column(SmallInteger) # should be in range 0-100 - start_arguments = Column(JSONB) + result = mapped_column(JSONB) + exception = mapped_column(JSONB) + finished_devices = mapped_column(JSONB) + change_score = mapped_column(SmallInteger) # should be in range 0-100 + start_arguments = mapped_column(JSONB) def as_dict(self) -> dict: """Return JSON serializable dict.""" @@ -82,8 +82,8 @@ def as_dict(self) -> dict: d[col.name] = value return d - def start_job(self, function_name: Optional[str] = None, scheduled_by: Optional[str] = None): - self.start_time = datetime.datetime.utcnow() + def start_job(self, function_name: Optional[str] = None, scheduled_by: str = ""): + self.start_time = datetime.datetime.utcnow() # type: ignore self.status = JobStatus.RUNNING self.finished_devices = [] if function_name: @@ -119,7 +119,7 @@ def finish_success(self, res: dict, next_job_id: Optional[int]): ) self.result = {"error": "unserializable"} - self.finish_time = datetime.datetime.utcnow() + self.finish_time = datetime.datetime.utcnow() # type: ignore if self.status == JobStatus.ABORTING: self.status = JobStatus.ABORTED else: @@ -136,13 +136,13 @@ def finish_success(self, res: dict, next_job_id: Optional[int]): except Exception: # noqa: S110 pass - def finish_exception(self, e: Exception, traceback: str): - logger.warning("Job {} finished with exception: {}".format(self.id, str(e))) - self.finish_time = datetime.datetime.utcnow() + def finish_exception(self, exc: Exception, traceback: str): + logger.warning("Job {} finished with exception: {}".format(self.id, str(exc))) + self.finish_time = datetime.datetime.utcnow() # type: ignore self.status = JobStatus.EXCEPTION try: self.exception = json.dumps( - {"message": str(e), "type": type(e).__name__, "args": e.args, "traceback": traceback}, + {"message": str(exc), "type": type(exc).__name__, "args": exc.args, "traceback": traceback}, default=json_dumper, ) except Exception as e: @@ -154,7 +154,7 @@ def finish_exception(self, e: Exception, traceback: str): { "job_id": self.id, "status": "EXCEPTION", - "exception": str(e), + "exception": str(exc), } ) add_event(json_data=json_data, event_type="update", update_type="job") @@ -163,7 +163,7 @@ def finish_exception(self, e: Exception, traceback: str): def finish_abort(self, message: str): logger.debug("Job {} aborted: {}".format(self.id, message)) - self.finish_time = datetime.datetime.utcnow() + self.finish_time = datetime.datetime.utcnow() # type: ignore self.status = JobStatus.ABORTED self.result = {"message": message} try: @@ -188,13 +188,12 @@ def clear_jobs(cls, session): job.status = JobStatus.ABORTED aborting_jobs = session.query(Job).filter(Job.status == JobStatus.ABORTING).all() - job: Job for job in aborting_jobs: logger.warning("Job found in unfinished ABORTING state at startup moved to ABORTED, id: {}".format(job.id)) job.status = JobStatus.ABORTED scheduled_jobs = session.query(Job).filter(Job.status == JobStatus.SCHEDULED).all() - job: Job + for job in scheduled_jobs: # Clear jobs that should have been run in the past, timing might need tuning if # APschedulers misfire_grace_time is modified diff --git a/src/cnaas_nms/db/joblock.py b/src/cnaas_nms/db/joblock.py index 99cb8cdb..c4334af2 100644 --- a/src/cnaas_nms/db/joblock.py +++ b/src/cnaas_nms/db/joblock.py @@ -1,12 +1,11 @@ import datetime from typing import Dict, Optional -from sqlalchemy import Boolean, Column, DateTime, ForeignKey, Integer, String -from sqlalchemy.exc import DBAPIError -from sqlalchemy.orm import relationship +from sqlalchemy import Boolean, DateTime, ForeignKey, Integer, String +from sqlalchemy.exc import DBAPIError, ProgrammingError +from sqlalchemy.orm import mapped_column, relationship import cnaas_nms.db.base -from cnaas_nms.db.session import sqla_session class JoblockError(Exception): @@ -15,11 +14,11 @@ class JoblockError(Exception): class Joblock(cnaas_nms.db.base.Base): __tablename__ = "joblock" - job_id = Column(Integer, ForeignKey("job.id"), unique=True, primary_key=True) + job_id = mapped_column(Integer, ForeignKey("job.id"), unique=True, primary_key=True) job = relationship("Job", foreign_keys=[job_id]) - name = Column(String(32), unique=True, nullable=False) - start_time = Column(DateTime, default=datetime.datetime.now) # onupdate=now - abort = Column(Boolean, default=False) + name = mapped_column(String(32), unique=True, nullable=False) + start_time = mapped_column(DateTime, default=datetime.datetime.now) # onupdate=now + abort = mapped_column(Boolean, default=False) def as_dict(self) -> dict: """Return JSON serializable dict.""" @@ -34,7 +33,7 @@ def as_dict(self) -> dict: return d @classmethod - def acquire_lock(cls, session: sqla_session, name: str, job_id: int) -> bool: + def acquire_lock(cls, session, name: str, job_id: int) -> bool: # type: ignore curlock = session.query(Joblock).filter(Joblock.name == name).one_or_none() if curlock: return False @@ -44,7 +43,7 @@ def acquire_lock(cls, session: sqla_session, name: str, job_id: int) -> bool: return True @classmethod - def release_lock(cls, session: sqla_session, name: Optional[str] = None, job_id: Optional[int] = None): + def release_lock(cls, session, name: Optional[str] = None, job_id: Optional[int] = None): if job_id: curlock = session.query(Joblock).filter(Joblock.job_id == job_id).one_or_none() elif name: @@ -61,7 +60,7 @@ def release_lock(cls, session: sqla_session, name: Optional[str] = None, job_id: @classmethod def get_lock( - cls, session: sqla_session, name: Optional[str] = None, job_id: Optional[int] = None + cls, session, name: Optional[str] = None, job_id: Optional[int] = None # type: ignore ) -> Optional[Dict[str, str]]: """ @@ -78,7 +77,7 @@ def get_lock( if job_id: curlock: Joblock = session.query(Joblock).filter(Joblock.job_id == job_id).one_or_none() elif name: - curlock: Joblock = session.query(Joblock).filter(Joblock.name == name).one_or_none() + curlock = session.query(Joblock).filter(Joblock.name == name).one_or_none() else: raise ValueError("Either name or jobid must be set to release lock") @@ -88,12 +87,11 @@ def get_lock( return None @classmethod - def clear_locks(cls, session: sqla_session): + def clear_locks(cls, session): """Clear/release all locks in the database.""" try: return session.query(Joblock).delete() - except DBAPIError as e: - if e.orig.pgcode == "42P01": - raise JoblockError("Jobblock table doesn't exist yet, we assume it will be created soon.") - else: - raise + except ProgrammingError: + raise JoblockError("Jobblock table doesn't exist yet, we assume it will be created soon.") + except DBAPIError: + raise diff --git a/src/cnaas_nms/db/linknet.py b/src/cnaas_nms/db/linknet.py index 3564d5e7..589b8fd1 100644 --- a/src/cnaas_nms/db/linknet.py +++ b/src/cnaas_nms/db/linknet.py @@ -3,8 +3,8 @@ import ipaddress from typing import List, Optional -from sqlalchemy import Column, ForeignKey, Integer, Unicode, UniqueConstraint -from sqlalchemy.orm import backref, relationship +from sqlalchemy import ForeignKey, Integer, Unicode, UniqueConstraint +from sqlalchemy.orm import backref, mapped_column, relationship from sqlalchemy_utils import IPAddressType import cnaas_nms.db.base @@ -20,23 +20,23 @@ class Linknet(cnaas_nms.db.base.Base): UniqueConstraint("device_a_id", "device_a_port"), UniqueConstraint("device_b_id", "device_b_port"), ) - id = Column(Integer, autoincrement=True, primary_key=True) - ipv4_network = Column(Unicode(18)) - device_a_id = Column(Integer, ForeignKey("device.id")) + id = mapped_column(Integer, autoincrement=True, primary_key=True) + ipv4_network = mapped_column(Unicode(18)) + device_a_id = mapped_column(Integer, ForeignKey("device.id")) device_a = relationship( "Device", foreign_keys=[device_a_id], backref=backref("linknets_a", cascade="all, delete-orphan") ) - device_a_ip = Column(IPAddressType) - device_a_port = Column(Unicode(64)) - device_b_id = Column(Integer, ForeignKey("device.id")) + device_a_ip = mapped_column(IPAddressType) + device_a_port = mapped_column(Unicode(64)) + device_b_id = mapped_column(Integer, ForeignKey("device.id")) device_b = relationship( "Device", foreign_keys=[device_b_id], backref=backref("linknets_b", cascade="all, delete-orphan") ) - device_b_ip = Column(IPAddressType) - device_b_port = Column(Unicode(64)) - site_id = Column(Integer, ForeignKey("site.id")) + device_b_ip = mapped_column(IPAddressType) + device_b_port = mapped_column(Unicode(64)) + site_id = mapped_column(Integer, ForeignKey("site.id")) site = relationship("Site") - description = Column(Unicode(255)) + description = mapped_column(Unicode(255)) def as_dict(self): """Return JSON serializable dict.""" diff --git a/src/cnaas_nms/db/mgmtdomain.py b/src/cnaas_nms/db/mgmtdomain.py index a00c0401..39bcdf21 100644 --- a/src/cnaas_nms/db/mgmtdomain.py +++ b/src/cnaas_nms/db/mgmtdomain.py @@ -3,10 +3,10 @@ import ipaddress from ipaddress import IPv4Address, IPv6Address, ip_interface from itertools import dropwhile, islice -from typing import Optional, Set, Union +from typing import List, Optional, Set, Union -from sqlalchemy import Column, ForeignKey, Integer, String, Unicode, UniqueConstraint -from sqlalchemy.orm import load_only, relationship +from sqlalchemy import ForeignKey, Integer, String, Unicode, UniqueConstraint +from sqlalchemy.orm import load_only, mapped_column, relationship from sqlalchemy_utils import IPAddressType import cnaas_nms.db.base @@ -25,20 +25,20 @@ class Mgmtdomain(cnaas_nms.db.base.Base): None, UniqueConstraint("device_a_id", "device_b_id"), ) - id = Column(Integer, autoincrement=True, primary_key=True) - ipv4_gw = Column(Unicode(18)) # 255.255.255.255/32 - ipv6_gw = Column(Unicode(43)) # fe80:0000:0000:0000:0000:0000:0000:0000/128 - device_a_id = Column(Integer, ForeignKey("device.id")) + id = mapped_column(Integer, autoincrement=True, primary_key=True) + ipv4_gw = mapped_column(Unicode(18)) # 255.255.255.255/32 + ipv6_gw = mapped_column(Unicode(43)) # fe80:0000:0000:0000:0000:0000:0000:0000/128 + device_a_id = mapped_column(Integer, ForeignKey("device.id")) device_a = relationship("Device", foreign_keys=[device_a_id]) - device_a_ip = Column(IPAddressType) - device_b_id = Column(Integer, ForeignKey("device.id")) + device_a_ip = mapped_column(IPAddressType) + device_b_id = mapped_column(Integer, ForeignKey("device.id")) device_b = relationship("Device", foreign_keys=[device_b_id]) - device_b_ip = Column(IPAddressType) - site_id = Column(Integer, ForeignKey("site.id")) + device_b_ip = mapped_column(IPAddressType) + site_id = mapped_column(Integer, ForeignKey("site.id")) site = relationship("Site") - vlan = Column(Integer) - description = Column(Unicode(255)) - esi_mac = Column(String(12)) + vlan = mapped_column(Integer) + description = mapped_column(Unicode(255)) + esi_mac = mapped_column(String(12)) def as_dict(self): """Return JSON serializable dict.""" @@ -67,13 +67,13 @@ def is_dual_stack(self) -> bool: return bool(self.ipv4_gw) and bool(self.ipv6_gw) @property - def primary_gw(self) -> Optional[str]: + def primary_gw(self) -> str: """Returns the primary gateway interface for this Mgmtdomain, depending on the configured preference""" primary_version = api_settings.MGMTDOMAIN_PRIMARY_IP_VERSION return self.ipv4_gw if primary_version == 4 else self.ipv6_gw @property - def secondary_gw(self) -> Optional[str]: + def secondary_gw(self) -> str: """Returns the secondary gateway interface for this Mgmtdomain, depending on the configured preference""" primary_version = api_settings.MGMTDOMAIN_PRIMARY_IP_VERSION return self.ipv6_gw if primary_version == 4 else self.ipv4_gw @@ -116,8 +116,8 @@ def is_taken(addr): else: mgmt_net = ip_interface(intf_addr).network candidates = islice(mgmt_net.hosts(), api_settings.MGMTDOMAIN_RESERVED_COUNT, None) - free_ips = dropwhile(is_taken, candidates) - return next(free_ips, None) + free_ips: List[IPAddress] = dropwhile(is_taken, candidates) # type: ignore + return next(free_ips, None) # type: ignore @staticmethod def _get_taken_ips(session) -> Set[IPAddress]: diff --git a/src/cnaas_nms/db/reservedip.py b/src/cnaas_nms/db/reservedip.py index 7027d66b..ae924e80 100644 --- a/src/cnaas_nms/db/reservedip.py +++ b/src/cnaas_nms/db/reservedip.py @@ -39,6 +39,8 @@ def clean_reservations( ): rip: Optional[ReservedIP] = None for rip in session.query(ReservedIP): + if not rip: + continue if device and rip.device == device: logger.debug("Clearing reservation of ip {} for device {}".format(rip.ip, device.hostname)) session.delete(rip) diff --git a/src/cnaas_nms/db/session.py b/src/cnaas_nms/db/session.py index 165e21a3..b6b04368 100644 --- a/src/cnaas_nms/db/session.py +++ b/src/cnaas_nms/db/session.py @@ -20,7 +20,7 @@ def _get_session(): @contextmanager -def sqla_session(**kwargs) -> sessionmaker: +def sqla_session(**kwargs): session = _get_session() try: yield session @@ -42,7 +42,7 @@ def sqla_execute(**kwargs): @contextmanager -def redis_session(**kwargs) -> StrictRedis: +def redis_session(**kwargs): with StrictRedis( host=app_settings.REDIS_HOSTNAME, port=app_settings.REDIS_PORT, encoding="utf-8", decode_responses=True ) as conn: diff --git a/src/cnaas_nms/db/settings.py b/src/cnaas_nms/db/settings.py index c6354f85..a43a76af 100644 --- a/src/cnaas_nms/db/settings.py +++ b/src/cnaas_nms/db/settings.py @@ -58,7 +58,7 @@ class VlanConflictError(Exception): DIR_STRUCTURE_HOST = {"base_system.yml": "file", "interfaces.yml": "file", "routing.yml": "file"} -DIR_STRUCTURE = { +DIR_STRUCTURE: dict[str, Any] = { "global": {"base_system.yml": "file", "groups.yml": "file", "routing.yml": "file", "vxlans.yml": "file"}, "fabric": {"base_system.yml": "file"}, "core": {"base_system.yml": "file"}, @@ -84,7 +84,7 @@ def get_model_specific_configfiles(only_modelname: bool = False) -> dict: 'DIST': ['interfaces_veos.yml'] } """ - ret = {"CORE": [], "DIST": []} + ret: dict[str, List[str]] = {"CORE": [], "DIST": []} local_repo_path = app_settings.SETTINGS_LOCAL for devtype in ["CORE", "DIST"]: @@ -204,7 +204,7 @@ def get_pydantic_error_value(data: dict, loc: tuple): def get_pydantic_field_descr(schema: dict, loc: tuple): """Get the description from a pydantic Field definition based on a model schema and a "loc" tuple from pydantic ValidatorError.errors()""" - next_schema = None + next_schema: dict[str, Any] | None = None for loc_part in loc: if next_schema and "$ref" in next_schema: ref_to = next_schema["$ref"].split("/")[2] @@ -216,7 +216,7 @@ def get_pydantic_field_descr(schema: dict, loc: tuple): next_schema = schema["definitions"][next_schema]["properties"][loc_part] else: next_schema = schema["properties"][loc_part] - if "description" in next_schema: + if next_schema and "description" in next_schema: return next_schema["description"] else: return None @@ -287,7 +287,7 @@ def check_settings_collisions(unique_vlans: bool = True): logger = get_logger() mgmt_vlans: Set[int] = set() devices_dict: dict[str, dict] = {} - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdoms = session.query(Mgmtdomain).all() for mgmtdom in mgmtdoms: if mgmtdom.vlan and isinstance(mgmtdom.vlan, int): @@ -329,7 +329,7 @@ def check_vlan_collisions(devices_dict: Dict[str, dict], mgmt_vlans: Set[int], u device_vlan_ids: dict[str, Set[int]] = {} # save used VLAN IDs per device device_vlan_names: dict[str, Set[str]] = {} # save used VLAN names per device access_hostnames: List[str] = [] - with sqla_session() as session: + with sqla_session() as session: # type: ignore access_devs = session.query(Device).filter(Device.device_type == DeviceType.ACCESS).all() for dev in access_devs: access_hostnames.append(dev.hostname) @@ -428,8 +428,8 @@ def read_settings( origin: str, merged_settings, merged_settings_origin, - groups: List[str] = None, - hostname: str = None, + groups: List[str] | None = None, + hostname: str | None = None, ) -> Tuple[dict, dict]: """ @@ -457,11 +457,22 @@ def read_settings( if groups or hostname: syntax_dict, syntax_dict_origin = merge_dict_origin({}, settings, {}, origin) check_settings_syntax(syntax_dict, syntax_dict_origin) - settings = filter_yamldata(settings, groups, hostname) + settings = filter_yamldata(settings, groups if groups else [], hostname if hostname else "") return merge_dict_origin(merged_settings, settings, merged_settings_origin, origin) -def filter_yamldata(data: Union[List, dict], groups: List[str], hostname: str, recdepth=100) -> Union[List, dict]: +def filter_yamldata(data: Union[List, dict], groups: List[str], hostname: str) -> dict: + logger = get_logger() + filtered_yaml_data = recursive_filter_yamldata(data, groups, hostname) + if not isinstance(filtered_yaml_data, dict): + logger.info("Invalid yaml file ignored") + return {} + return filtered_yaml_data + + +def recursive_filter_yamldata( + data: Union[List, dict], groups: List[str], hostname: str, recdepth=100 +) -> Union[List, dict, None]: """Filter data and remove dictionary items if they have a key that specifies a list of groups, but none of those groups are included in the groups argument. Should only be called with yaml.safe_load:ed data. @@ -480,7 +491,7 @@ def filter_yamldata(data: Union[List, dict], groups: List[str], hostname: str, r elif isinstance(data, list): ret_l = [] for item in data: - f_item = filter_yamldata(item, groups, hostname, recdepth - 1) + f_item = recursive_filter_yamldata(item, groups, hostname, recdepth - 1) if f_item: ret_l.append(f_item) return ret_l @@ -490,33 +501,33 @@ def filter_yamldata(data: Union[List, dict], groups: List[str], hostname: str, r hostname_match = False do_filter_group = False do_filter_hostname = False - for k, v in data.items(): - if not v: - ret_d[k] = v + for key, value in data.items(): + if not value: + ret_d[key] = value continue - if k == "groups": - if not isinstance(v, list): # Should already be checked by pydantic now + if key == "groups": + if not isinstance(value, list): # Should already be checked by pydantic now raise SettingsSyntaxError( - "Groups field must be a list or empty (currently {}) in: {}".format(type(v).__name__, data) + "Groups field must be a list or empty (currently {}) in: {}".format(type(value).__name__, data) ) do_filter_group = True - ret_d[k] = v - for group in v: + ret_d[key] = value + for group in value: if group in groups: group_match = True - elif k == "devices": - if not isinstance(v, list): # Should already be checked by pydantic now + elif key == "devices": + if not isinstance(value, list): # Should already be checked by pydantic now raise SettingsSyntaxError( - "Devices field must be a list or empty (currently {}) in: {}".format(type(v).__name__, data) + "Devices field must be a list or empty (currently {}) in: {}".format(type(value).__name__, data) ) do_filter_hostname = True - ret_d[k] = v - if hostname in v: + ret_d[key] = value + if hostname in value: hostname_match = True else: - ret_v = filter_yamldata(v, groups, hostname, recdepth - 1) + ret_v = recursive_filter_yamldata(value, groups, hostname, recdepth - 1) if ret_v: - ret_d[k] = ret_v + ret_d[key] = ret_v if (do_filter_group or do_filter_hostname) and not group_match and not hostname_match: return None else: @@ -526,7 +537,7 @@ def filter_yamldata(data: Union[List, dict], groups: List[str], hostname: str, r def get_downstream_dependencies(hostname: str, settings: dict) -> dict: - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: return settings @@ -736,7 +747,7 @@ def get_group_settings() -> Tuple[dict, dict]: @redis_lru_cache def get_groups(hostname: Optional[str] = None) -> List[str]: """Return list of names for valid groups.""" - groups = [] + groups: list[str] = [] settings, origin = get_group_settings() if not settings: return groups @@ -779,14 +790,22 @@ def get_group_settings_asdict() -> Dict[str, Dict[str, Any]]: for group in settings["groups"]: if "name" not in group["group"]: continue +<<<<<<< HEAD group_dict[group["group"]["name"]] = group["group"] del group_dict[group["group"]["name"]]["name"] return group_dict +======= + if "regex" not in group["group"]: + continue + if group_name == group["group"]["name"]: + return group["group"]["regex"] + return None +>>>>>>> 52d76c6 (add mypy to pre-commit and solved mypy errors) def get_groups_priorities(hostname: Optional[str] = None, settings: Optional[dict] = None) -> Dict[str, int]: """Return dicts with {name: priority} for groups""" - groups_priorities = {} + groups_priorities: dict[str, Any] = {} if not settings: settings, _ = get_group_settings() @@ -832,7 +851,7 @@ def parse_device_primary_groups() -> Dict[str, str]: """Returns a dict with {hostname: primary_group} from settings""" groups_priorities_sorted = get_groups_priorities_sorted() device_primary_group: Dict[str, str] = {} - with sqla_session() as session: + with sqla_session() as session: # type: ignore devices: List[Device] = session.query(Device).all() for dev in devices: groups = get_groups(dev.hostname) @@ -845,7 +864,7 @@ def update_device_primary_groups(): device_primary_group = parse_device_primary_groups() if not device_primary_group: return - with redis_session() as redis: + with redis_session() as redis: # type: ignore redis.hset("device_primary_group", mapping=device_primary_group) @@ -857,13 +876,13 @@ def get_device_primary_groups(no_cache: bool = False) -> Dict[str, str]: """ logger = get_logger() # update redis if redis is empty - with redis_session() as redis: + with redis_session() as redis: # type: ignore if not redis.exists("device_primary_group"): update_device_primary_groups() if no_cache: update_device_primary_groups() device_primary_group: dict = {} - with redis_session() as redis: + with redis_session() as redis: # type: ignore try: device_primary_group = redis.hgetall("device_primary_group") except Exception as e: @@ -880,7 +899,7 @@ def rebuild_settings_cache() -> None: """ logger = get_logger() logger.debug("Clearing redis-lru cache for settings") - with redis_session() as redis_db: + with redis_session() as redis_db: # type: ignore mem_stats_before = redis_db.memory_stats() cache = RedisLRU(redis_db) cache.clear_all_cache() @@ -904,7 +923,7 @@ def rebuild_settings_cache() -> None: for devtype in test_devtypes: get_settings(device_type=devtype) logger.debug("Rebuilding settings cache for device specific settings") - with sqla_session() as session: + with sqla_session() as session: # type: ignore for hostname in os.listdir(os.path.join(app_settings.SETTINGS_LOCAL, "devices")): hostname_path = os.path.join(app_settings.SETTINGS_LOCAL, "devices", hostname) if not os.path.isdir(hostname_path) or hostname.startswith("."): diff --git a/src/cnaas_nms/db/settings_fields.py b/src/cnaas_nms/db/settings_fields.py index b3943c29..3ce53c5d 100644 --- a/src/cnaas_nms/db/settings_fields.py +++ b/src/cnaas_nms/db/settings_fields.py @@ -1,7 +1,7 @@ from ipaddress import AddressValueError, IPv4Interface from typing import Annotated, Dict, List, Optional, Union -from pydantic import BaseModel, Field, FieldValidationInfo, conint, field_validator +from pydantic import BaseModel, Field, ValidationInfo, field_validator from pydantic.functional_validators import AfterValidator # HOSTNAME_REGEX = r'([a-z0-9-]{1,63}\.?)+' @@ -46,8 +46,9 @@ vxlan_vni_schema = Field(..., gt=0, lt=16777215, description="VXLAN Network Identifier") vrf_id_schema = Field(..., gt=0, lt=65536, description="VRF identifier, integer between 1-65535") mtu_schema = Field(None, ge=68, le=9214, description="MTU (Maximum transmission unit) value between 68-9214") -as_num_schema = Field(None, description="BGP Autonomous System number, 1-4294967295 (asdot notation not supported)") -as_num_type = conint(strict=True, gt=0, lt=4294967296) +as_num_schema = Field( + None, gt=0, lt=4294967296, description="BGP Autonomous System number, 1-4294967295 (asdot notation not supported)" +) IFNAME_REGEX = r"([a-zA-Z0-9\/\.:-])+" ifname_schema = Field(None, pattern=f"^{IFNAME_REGEX}$", description="Interface name") IFNAME_RANGE_REGEX = r"([a-zA-Z0-9\/\.:\-\[\]])+" @@ -167,7 +168,7 @@ class f_interface(BaseModel): @field_validator("ipv4_address") @classmethod - def vrf_required_if_ipv4_address_set(cls, v: str, info: FieldValidationInfo): + def vrf_required_if_ipv4_address_set(cls, v: str, info: ValidationInfo): if v: validate_ipv4_if(v) if "vrf" not in info.data or not info.data["vrf"]: @@ -176,7 +177,7 @@ def vrf_required_if_ipv4_address_set(cls, v: str, info: FieldValidationInfo): class f_vrf(BaseModel): - name: str = None + name: Optional[str] = None vrf_id: int = vrf_id_schema import_route_targets: List[str] = [] export_route_targets: List[str] = [] @@ -224,7 +225,7 @@ class f_extroute_ospfv3(BaseModel): class f_extroute_bgp_neighbor_v4(BaseModel): peer_ipv4: str = ipv4_schema - peer_as: as_num_type = as_num_schema + peer_as: str = as_num_schema route_map_in: str = vlan_name_schema route_map_out: str = vlan_name_schema description: str = "undefined" @@ -241,7 +242,7 @@ class f_extroute_bgp_neighbor_v4(BaseModel): class f_extroute_bgp_neighbor_v6(BaseModel): peer_ipv6: str = ipv6_schema - peer_as: as_num_type = as_num_schema + peer_as: str = as_num_schema route_map_in: str = vlan_name_schema route_map_out: str = vlan_name_schema description: str = "undefined" @@ -258,7 +259,7 @@ class f_extroute_bgp_neighbor_v6(BaseModel): class f_extroute_bgp_vrf(BaseModel): name: str - local_as: as_num_type = as_num_schema + local_as: str = as_num_schema neighbor_v4: List[f_extroute_bgp_neighbor_v4] = [] neighbor_v6: List[f_extroute_bgp_neighbor_v6] = [] cli_append_str: str = "" @@ -275,7 +276,7 @@ class f_internal_vlans(BaseModel): @field_validator("vlan_id_high") @classmethod - def vlan_id_high_greater_than_low(cls, v: int, info: FieldValidationInfo): + def vlan_id_high_greater_than_low(cls, v: int, info: ValidationInfo): if v: if info.data["vlan_id_low"] >= v: raise ValueError("vlan_id_high must be greater than vlan_id_low") @@ -305,7 +306,7 @@ class f_vxlan(BaseModel): @field_validator("ipv4_gw") @classmethod - def vrf_required_if_ipv4_gw_set(cls, v: str, info: FieldValidationInfo): + def vrf_required_if_ipv4_gw_set(cls, v: str, info: ValidationInfo): if v: validate_ipv4_if(v) if "vrf" not in info.data or not info.data["vrf"]: @@ -314,7 +315,7 @@ def vrf_required_if_ipv4_gw_set(cls, v: str, info: FieldValidationInfo): @field_validator("ipv6_gw") @classmethod - def vrf_required_if_ipv6_gw_set(cls, v: str, info: FieldValidationInfo): + def vrf_required_if_ipv6_gw_set(cls, v: str, info: ValidationInfo): if v: if "vrf" not in info.data or not info.data["vrf"]: raise ValueError("VRF is required when specifying ipv6_gw") @@ -325,7 +326,7 @@ class f_underlay(BaseModel): infra_lo_net: str = ipv4_if_schema infra_link_net: str = ipv4_if_schema mgmt_lo_net: str = ipv4_if_schema - bgp_asn: Optional[as_num_type] = as_num_schema + bgp_asn: Optional[str] = as_num_schema class f_user(BaseModel): @@ -376,7 +377,7 @@ class f_root(BaseModel): interfaces: List[f_interface] = [] vrfs: List[f_vrf] = [] vxlans: Dict[str, f_vxlan] = {} - underlay: f_underlay = None + underlay: Optional[f_underlay] = None evpn_peers: List[f_evpn_peer] = [] extroute_static: Optional[f_extroute_static] = None extroute_ospfv3: Optional[f_extroute_ospfv3] = None @@ -402,7 +403,7 @@ class f_group_item(BaseModel): @field_validator("group_priority") @classmethod - def reserved_priority(cls, v: int, info: FieldValidationInfo): + def reserved_priority(cls, v: int, info: ValidationInfo): if v and v == 1 and info.data["name"] != "DEFAULT": raise ValueError("group_priority 1 is reserved for built-in group DEFAULT") return v diff --git a/src/cnaas_nms/db/stackmember.py b/src/cnaas_nms/db/stackmember.py index f4c890d9..4253dd94 100644 --- a/src/cnaas_nms/db/stackmember.py +++ b/src/cnaas_nms/db/stackmember.py @@ -1,5 +1,5 @@ -from sqlalchemy import Column, ForeignKey, Integer, String, UniqueConstraint -from sqlalchemy.orm import relationship +from sqlalchemy import ForeignKey, Integer, String, UniqueConstraint +from sqlalchemy.orm import mapped_column, relationship import cnaas_nms.db.base @@ -10,12 +10,12 @@ class Stackmember(cnaas_nms.db.base.Base): UniqueConstraint("device_id", "member_no"), UniqueConstraint("device_id", "hardware_id"), ) - id = Column(Integer, autoincrement=True, primary_key=True) - device_id = Column(Integer, ForeignKey("device.id"), nullable=False) + id = mapped_column(Integer, autoincrement=True, primary_key=True) + device_id = mapped_column(Integer, ForeignKey("device.id"), nullable=False) device = relationship("Device", back_populates="stack_members") - hardware_id = Column(String(64), nullable=False) - member_no = Column(Integer) - priority = Column(Integer) + hardware_id = mapped_column(String(64), nullable=False) + member_no = mapped_column(Integer) + priority = mapped_column(Integer) def as_dict(self) -> dict: """Return JSON serializable dict.""" diff --git a/src/cnaas_nms/db/tests/test_device.py b/src/cnaas_nms/db/tests/test_device.py index c37ea672..d7ad7953 100644 --- a/src/cnaas_nms/db/tests/test_device.py +++ b/src/cnaas_nms/db/tests/test_device.py @@ -19,7 +19,7 @@ def requirements(self, postgresql): pass def cleandb(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore for hardware_id in ["FO64534", "FO64535"]: stack = session.query(Stackmember).filter(Stackmember.hardware_id == hardware_id).one_or_none() if stack: @@ -51,7 +51,7 @@ def create_test_device(cls, hostname="unittest"): def test_get_linknets(self): device1 = DeviceTests.create_test_device("test-device1") device2 = DeviceTests.create_test_device("test-device2") - with sqla_session() as session: + with sqla_session() as session: # type: ignore session.add(device1) session.add(device2) test_linknet = Linknet(device_a=device1, device_b=device2) @@ -64,7 +64,7 @@ def test_get_linknets(self): def test_get_links_to(self): device1 = DeviceTests.create_test_device("test-device1") device2 = DeviceTests.create_test_device("test-device2") - with sqla_session() as session: + with sqla_session() as session: # type: ignore session.add(device1) session.add(device2) test_linknet = Linknet(device_a=device1, device_b=device2) @@ -77,7 +77,7 @@ def test_get_links_to(self): def test_get_neighbors(self): device1 = DeviceTests.create_test_device("test-device1") device2 = DeviceTests.create_test_device("test-device2") - with sqla_session() as session: + with sqla_session() as session: # type: ignore session.add(device1) session.add(device2) new_linknet = Linknet(device_a=device1, device_b=device2) @@ -88,7 +88,7 @@ def test_get_neighbors(self): self.assertEqual(set([device1]), device2.get_neighbors(session)) def test_is_stack(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore new_stack = DeviceTests.create_test_device() session.add(new_stack) session.flush() @@ -103,7 +103,7 @@ def test_is_stack(self): self.assertFalse(new_stack.is_stack(session)) def test_get_stackmembers(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore new_stack = DeviceTests.create_test_device() session.add(new_stack) session.flush() diff --git a/src/cnaas_nms/db/tests/test_git.py b/src/cnaas_nms/db/tests/test_git.py index 5cc8a3fe..6a54d8bb 100644 --- a/src/cnaas_nms/db/tests/test_git.py +++ b/src/cnaas_nms/db/tests/test_git.py @@ -1,9 +1,10 @@ import unittest +from typing import Set, Tuple import pytest from cnaas_nms.db.device import DeviceType -from cnaas_nms.db.git import RepoType, repo_chekout_working, repo_save_working_commit, template_syncstatus +from cnaas_nms.db.git import RepoType, repo_checkout_working, repo_save_working_commit, template_syncstatus from cnaas_nms.db.session import redis_session @@ -15,17 +16,17 @@ def requirements(self, redis): pass def setUp(self) -> None: - with redis_session() as redis: + with redis_session() as redis: # type: ignore redis.delete("SETTINGS_working_commit") redis.delete("TEMPLATES_working_commit") def tearDown(self) -> None: - with redis_session() as redis: + with redis_session() as redis: # type: ignore redis.delete("SETTINGS_working_commit") redis.delete("TEMPLATES_working_commit") def test_check_unsync(self): - devtypes = template_syncstatus({"eos/access-base.j2"}) + devtypes: Set[Tuple[DeviceType, str]] = template_syncstatus({"eos/access-base.j2"}) for devtype in devtypes: self.assertEqual(type(devtype[0]), DeviceType) self.assertEqual(type(devtype[1]), str) @@ -33,15 +34,15 @@ def test_check_unsync(self): def test_savecommit(self): self.assertFalse( - repo_chekout_working(RepoType.SETTINGS, dry_run=True), "Redis working commit not cleared at setUp" + repo_checkout_working(RepoType.SETTINGS, dry_run=True), "Redis working commit not cleared at setUp" ) self.assertFalse( - repo_chekout_working(RepoType.TEMPLATES, dry_run=True), "Redis working commit not cleared at setUp" + repo_checkout_working(RepoType.TEMPLATES, dry_run=True), "Redis working commit not cleared at setUp" ) repo_save_working_commit(RepoType.SETTINGS, "bd5e1f70f52037e8e2a451b2968a9ca8160a7cba") repo_save_working_commit(RepoType.TEMPLATES, "bd5e1f70f52037e8e2a451b2968a9ca8160a7cba") - self.assertTrue(repo_chekout_working(RepoType.SETTINGS, dry_run=True), "Working commit not saved in redis") - self.assertTrue(repo_chekout_working(RepoType.TEMPLATES, dry_run=True), "Working commit not saved in redis") + self.assertTrue(repo_checkout_working(RepoType.SETTINGS, dry_run=True), "Working commit not saved in redis") + self.assertTrue(repo_checkout_working(RepoType.TEMPLATES, dry_run=True), "Working commit not saved in redis") if __name__ == "__main__": diff --git a/src/cnaas_nms/db/tests/test_mgmtdomain.py b/src/cnaas_nms/db/tests/test_mgmtdomain.py index 8a65c710..208706ba 100644 --- a/src/cnaas_nms/db/tests/test_mgmtdomain.py +++ b/src/cnaas_nms/db/tests/test_mgmtdomain.py @@ -42,7 +42,7 @@ def tearDownClass(cls) -> None: @classmethod def add_mgmtdomain(cls): testdata = cls.get_testdata() - with sqla_session() as session: + with sqla_session() as session: # type: ignore d_a = DeviceTests.create_test_device("mgmtdomaintest1") d_b = DeviceTests.create_test_device("mgmtdomaintest2") session.add(d_a) @@ -57,7 +57,7 @@ def add_mgmtdomain(cls): @classmethod def delete_mgmtdomain(cls): - with sqla_session() as session: + with sqla_session() as session: # type: ignore d_a = session.query(Device).filter(Device.hostname == "mgmtdomaintest1").one() instance = session.query(Mgmtdomain).filter(Mgmtdomain.device_a == d_a).first() session.delete(instance) @@ -67,66 +67,66 @@ def delete_mgmtdomain(cls): session.delete(d_b) def test_find_mgmtdomain_invalid(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore self.assertRaises(ValueError, cnaas_nms.db.helper.find_mgmtdomain, session, []) self.assertRaises(ValueError, cnaas_nms.db.helper.find_mgmtdomain, session, [1, 2, 3]) def test_find_mgmtdomain_twodist(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain(session, ["eosdist1", "eosdist2"]) self.assertIsNotNone(mgmtdomain, "No mgmtdomain found for eosdist1 + eosdist2") def test_find_mgmtdomain_onedist(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain(session, ["eosdist1"]) self.assertIsNotNone(mgmtdomain, "No mgmtdomain found for eosdist1") def test_find_mgmtdomain_oneaccess(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain(session, ["eosaccess"]) self.assertIsNotNone(mgmtdomain, "No mgmtdomain found for eosaccess") def test_is_dual_stack_should_be_false_for_default_domain(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() self.assertFalse(mgmtdomain.is_dual_stack) # domain in test data is not dual stack def test_primary_gw_should_be_ipv4_for_default_domain(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() self.assertEqual(mgmtdomain.ipv4_gw, mgmtdomain.primary_gw) def test_find_free_primary_mgmt_ip_should_return_an_ipv4_address(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() value = mgmtdomain.find_free_primary_mgmt_ip(session) self.assertTrue(value) self.assertIsInstance(value, IPv4Address) def test_find_free_secondary_mgmt_ip_should_return_none(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() value = mgmtdomain.find_free_secondary_mgmt_ip(session) self.assertIsNone(value) # domain in test data has no secondary network def test_find_free_mgmt_ip(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() mgmtdomain.find_free_mgmt_ip(session) def test_find_free_mgmt_ip_v6(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() mgmtdomain.find_free_mgmt_ip(session, version=6) def test_find_free_mgmt_ip_should_fail_on_invalid_ip_version(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = session.query(Mgmtdomain).limit(1).one() with self.assertRaises(ValueError): mgmtdomain.find_free_mgmt_ip(session, version=42) def test_find_mgmtdomain_by_ip(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain_by_ip(session, IPv4Address("10.0.6.6")) self.assertEqual(IPv4Interface(mgmtdomain.ipv4_gw).network, IPv4Network("10.0.6.0/24")) diff --git a/src/cnaas_nms/devicehandler/cert.py b/src/cnaas_nms/devicehandler/cert.py index cdac3644..6b957988 100644 --- a/src/cnaas_nms/devicehandler/cert.py +++ b/src/cnaas_nms/devicehandler/cert.py @@ -17,7 +17,7 @@ class CopyError(Exception): pass -def arista_copy_cert(task, job_id: Optional[str] = None) -> str: +def arista_copy_cert(task, job_id: Optional[int] = None) -> str: set_thread_data(job_id) logger = get_logger() @@ -85,7 +85,7 @@ def renew_cert_task(task, job_id: str) -> str: set_thread_data(job_id) logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == task.host.name).one_or_none() ip = dev.management_ip if not ip: @@ -112,8 +112,8 @@ def renew_cert_task(task, job_id: str) -> str: def renew_cert( hostname: Optional[str] = None, group: Optional[str] = None, - job_id: Optional[str] = None, - scheduled_by: Optional[str] = None, + job_id: Optional[int] = None, + scheduled_by: str = "", ) -> NornirJobResult: logger = get_logger() nr = cnaas_init() @@ -130,7 +130,7 @@ def renew_cert( supported_platforms = ["eos"] # Make sure we only attempt supported devices for device in device_list: - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == device).one_or_none() if not dev: raise Exception("Could not find device: {}".format(device)) diff --git a/src/cnaas_nms/devicehandler/changescore.py b/src/cnaas_nms/devicehandler/changescore.py index e108a906..7c1e4290 100644 --- a/src/cnaas_nms/devicehandler/changescore.py +++ b/src/cnaas_nms/devicehandler/changescore.py @@ -1,10 +1,11 @@ import re +from typing import List line_start = r"^[+-][ ]*" line_start_remove = r"^-[ ]*" DEFAULT_LINE_SCORE = 1.0 # Stops looking after first match. Only searches a single line at a time. -change_patterns = [ +change_patterns: List[dict[str, str | float | re.Pattern]] = [ {"name": "description", "regex": re.compile(str(line_start + r"description")), "modifier": 0.0}, {"name": "name", "regex": re.compile(str(line_start + r"name")), "modifier": 0.0}, {"name": "comment", "regex": re.compile(str(line_start + r"!")), "modifier": 0.0}, @@ -28,10 +29,10 @@ # TODO: multiline patterns / block-aware config -def calculate_line_score(line: str): +def calculate_line_score(line: str) -> float: for pattern in change_patterns: - if re.match(pattern["regex"], line): - return 1 * pattern["modifier"] + if re.match(str(pattern["regex"]), line): + return float(1) * float(str(pattern["modifier"])) return DEFAULT_LINE_SCORE diff --git a/src/cnaas_nms/devicehandler/erase.py b/src/cnaas_nms/devicehandler/erase.py index 30de381a..79d3e22f 100644 --- a/src/cnaas_nms/devicehandler/erase.py +++ b/src/cnaas_nms/devicehandler/erase.py @@ -50,9 +50,11 @@ def device_erase_task(task, hostname: str, job_id: int) -> str: @job_wrapper -def device_erase(device_id: int = None, job_id: int = None, scheduled_by: Optional[str] = None) -> NornirJobResult: +def device_erase( + device_id: Optional[int] = None, job_id: Optional[int] = None, scheduled_by: str = "" +) -> NornirJobResult: logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() if dev: hostname = dev.hostname @@ -88,8 +90,8 @@ def device_erase(device_id: int = None, job_id: int = None, scheduled_by: Option logger.error("Factory default failed") if failed_hosts == []: - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.id == device_id).one_or_none() remove_sync_events(dev.hostname) try: for nei in dev.get_neighbors(session): diff --git a/src/cnaas_nms/devicehandler/firmware.py b/src/cnaas_nms/devicehandler/firmware.py index 0bd9b6a1..26aaf598 100644 --- a/src/cnaas_nms/devicehandler/firmware.py +++ b/src/cnaas_nms/devicehandler/firmware.py @@ -21,7 +21,7 @@ class FirmwareAlreadyActiveException(Exception): pass -def arista_pre_flight_check(task, job_id: Optional[int] = None) -> str: +def arista_pre_flight_check(task, job_id: Optional[int] = None) -> str: # type: ignore """ NorNir task to do some basic checks before attempting to upgrade a switch. @@ -34,7 +34,7 @@ def arista_pre_flight_check(task, job_id: Optional[int] = None) -> str: """ set_thread_data(job_id) logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore if Job.check_job_abort_status(session, job_id): return "Pre-flight aborted" @@ -75,7 +75,7 @@ def arista_post_flight_check(task, post_waittime: int, scheduled_by: str, job_id logger = get_logger() time.sleep(int(post_waittime)) logger.info("Post-flight check wait ({}s) complete, starting check for {}".format(post_waittime, task.host.name)) - with sqla_session() as session: + with sqla_session() as session: # type: ignore if Job.check_job_abort_status(session, job_id): return "Post-flight aborted" @@ -83,7 +83,7 @@ def arista_post_flight_check(task, post_waittime: int, scheduled_by: str, job_id res = task.run(napalm_get, getters=["facts"]) os_version = res[0].result["facts"]["os_version"] - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == task.host.name).one() prev_os_version = dev.os_version dev.os_version = os_version @@ -94,7 +94,7 @@ def arista_post_flight_check(task, post_waittime: int, scheduled_by: str, job_id dev.confhash = None dev.synchronized = False add_sync_event(task.host.name, "firmware_upgrade", scheduled_by, job_id) - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore except Exception as e: logger.exception("Could not update OS version on device {}: {}".format(task.host.name, str(e))) return "Post-flight failed, could not update OS version: {}".format(str(e)) @@ -118,14 +118,14 @@ def arista_firmware_download(task, filename: str, httpd_url: str, job_id: Option """ set_thread_data(job_id) logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore if Job.check_job_abort_status(session, job_id): return "Firmware download aborted" url = httpd_url + "/" + filename try: - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == task.host.name).one_or_none() device_type = dev.device_type @@ -162,7 +162,7 @@ def arista_firmware_download(task, filename: str, httpd_url: str, job_id: Option return "Firmware download done." -def arista_firmware_activate(task, filename: str, job_id: Optional[int] = None) -> str: +def arista_firmware_activate(task, filename: str, job_id: Optional[int] = None) -> str: # type: ignore """ NorNir task to modify the boot config for new firmwares. @@ -177,7 +177,7 @@ def arista_firmware_activate(task, filename: str, job_id: Optional[int] = None) """ set_thread_data(job_id) logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore if Job.check_job_abort_status(session, job_id): return "Firmware activate aborted" @@ -214,7 +214,7 @@ def arista_firmware_activate(task, filename: str, job_id: Optional[int] = None) return "Firmware activate done." -def arista_device_reboot(task, job_id: Optional[int] = None) -> str: +def arista_device_reboot(task, job_id: Optional[int] = None) -> str: # type: ignore """ NorNir task to reboot a single device. @@ -228,7 +228,7 @@ def arista_device_reboot(task, job_id: Optional[int] = None) -> str: """ set_thread_data(job_id) logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore if Job.check_job_abort_status(session, job_id): return "Reboot aborted" @@ -246,7 +246,7 @@ def arista_device_reboot(task, job_id: Optional[int] = None) -> str: def device_upgrade_task( - task, + task, # type: ignore job_id: int, scheduled_by: str, filename: str, @@ -337,9 +337,11 @@ def device_upgrade_task( logger.error("Post-flight check failed for: {}".format(" ".join(res.failed_hosts.keys()))) if job_id: - with redis_session() as db: + with redis_session() as db: # type: ignore db.lpush("finished_devices_" + str(job_id), task.host.name) + return res + @job_wrapper def device_upgrade( @@ -354,7 +356,7 @@ def device_upgrade( post_flight: Optional[bool] = False, post_waittime: Optional[int] = 600, reboot: Optional[bool] = False, - scheduled_by: Optional[str] = None, + scheduled_by: str = "", ) -> NornirJobResult: logger = get_logger() nr = cnaas_init() @@ -374,7 +376,7 @@ def device_upgrade( # Make sure we only upgrade Arista access switches for device in device_list: - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == device).one_or_none() if not dev: raise Exception("Could not find device: {}".format(device)) diff --git a/src/cnaas_nms/devicehandler/get.py b/src/cnaas_nms/devicehandler/get.py index b5a0f457..155237f1 100644 --- a/src/cnaas_nms/devicehandler/get.py +++ b/src/cnaas_nms/devicehandler/get.py @@ -1,6 +1,6 @@ import hashlib import re -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Set from netutils.config import compliance from netutils.lib_mapper import NAPALM_LIB_MAPPER @@ -13,7 +13,6 @@ from cnaas_nms.db.device import Device, DeviceType from cnaas_nms.db.device_vars import expand_interface_settings from cnaas_nms.db.interface import Interface, InterfaceConfigType, InterfaceError -from cnaas_nms.db.session import sqla_session from cnaas_nms.tools.log import get_logger @@ -22,7 +21,7 @@ def get_inventory(): return nr.dict()["inventory"] -def get_running_config(hostname: str) -> Optional[str]: +def get_running_config(hostname: str) -> str: nr = cnaas_nms.devicehandler.nornir_helper.cnaas_init() nr_filtered = nr.filter(name=hostname).filter(managed=True) nr_result = nr_filtered.run(task=napalm_get, getters=["config"]) @@ -32,10 +31,10 @@ def get_running_config(hostname: str) -> Optional[str]: return nr_result[hostname].result["config"]["running"] -def get_running_config_interface(session: sqla_session, hostname: str, interface: str) -> str: +def get_running_config_interface(session, hostname: str, interface: str) -> str: running_config = get_running_config(hostname) dev: Device = session.query(Device).filter(Device.hostname == hostname).one() - os_parser = compliance.parser_map[NAPALM_LIB_MAPPER.get(dev.platform)] + os_parser = compliance.parser_map[str(NAPALM_LIB_MAPPER.get(str(dev.platform)))] config_parsed = os_parser(running_config) ret = [] leading_whitespace: Optional[int] = None @@ -89,7 +88,7 @@ def get_uplinks( session, hostname: str, recheck: bool = False, - neighbors: Optional[List[Device]] = None, + neighbors: Optional[Set[Device]] = None, linknets: Optional[List[dict]] = None, ) -> Dict[str, str]: """Returns dict with mapping of interface -> neighbor hostname""" @@ -122,7 +121,7 @@ def get_uplinks( if not neighbors: neighbors = dev.get_neighbors(session, linknets) - for neighbor_d in neighbors: + for neighbor_d in neighbors: # type: ignore if neighbor_d.device_type == DeviceType.DIST: local_ifs = dev.get_neighbor_ifnames(session, neighbor_d, linknets) # Neighbor interface ifclass is already verified in @@ -162,7 +161,7 @@ def get_uplinks( def get_local_ifnames(local_devid: int, peer_devid: int, linknets: List[dict]) -> List[str]: - ifnames = [] + ifnames: List[str] = [] if not linknets: return ifnames for linknet in linknets: @@ -173,9 +172,7 @@ def get_local_ifnames(local_devid: int, peer_devid: int, linknets: List[dict]) - return ifnames -def get_mlag_ifs( - session, dev: Device, mlag_peer_hostname: str, linknets: Optional[List[dict]] = None -) -> Dict[str, int]: +def get_mlag_ifs(session, dev: Device, mlag_peer_hostname: str, linknets: List[dict] = []) -> Dict[str, int]: """Returns dict with mapping of interface -> neighbor id Return id instead of hostname since mlag peer will change hostname during init""" logger = get_logger() @@ -218,7 +215,7 @@ def get_interfaces_names(hostname: str) -> List[str]: return list(getfacts_task.result["interfaces"].keys()) -def filter_interfaces(iflist, platform=None, include=None): +def filter_interfaces(iflist: List[str], platform=None, include=None) -> List[str]: # TODO: include pattern matching from external configurable file ret = [] junos_phy_r = r"^(ge|xe|et|mge)-([0-9]+\/)+[0-9]+$" diff --git a/src/cnaas_nms/devicehandler/init_device.py b/src/cnaas_nms/devicehandler/init_device.py index bdea8c29..7c24a17b 100644 --- a/src/cnaas_nms/devicehandler/init_device.py +++ b/src/cnaas_nms/devicehandler/init_device.py @@ -1,7 +1,7 @@ import datetime import os from ipaddress import IPv4Address, IPv4Interface, ip_interface -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import napalm.base.exceptions import yaml @@ -74,9 +74,11 @@ def push_base_management(task, device_variables: dict, devtype: DeviceType, job_ task=ztp_device_cert, job_id=job_id, new_hostname=task.host.name, management_ip=device_variables["mgmt_ip"] ) except NornirSubTaskError as e: - copy_res: Result = next(iter([res for res in e.result if res.name == "arista_copy_cert"]), None) + copy_res: MultiResult | None = next(iter([res for res in e.result if res.name == "arista_copy_cert"]), None) if copy_res: - nm_res: Result = next(iter([sres for sres in copy_res if sres.name == "netmiko_file_transfer"]), None) + nm_res: Result | None = next( + iter([sres for sres in copy_res if sres.name == "netmiko_file_transfer"]), None + ) if nm_res and isinstance(nm_res.exception, NMReadTimeout): logger.error("Read timeout while copying cert to device") @@ -121,11 +123,9 @@ def push_base_management(task, device_variables: dict, devtype: DeviceType, job_ raise InitError("Device {} did not commit new base management config".format(task.host.name)) -def pre_init_checks(session, device_id: int, accepted_state: Optional[List[DeviceState]] = None) -> Device: +def pre_init_checks(session, device_id: int, accepted_state: List[DeviceState] = [DeviceState.DISCOVERED]) -> Device: """Find device with device_id and check that it's ready for init, returns Device object or raises exception""" - if not accepted_state: - accepted_state: List[DeviceState] = [DeviceState.DISCOVERED] # Check that we can find device and that it's in the correct state to start init dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() if not dev: @@ -259,7 +259,7 @@ def pre_init_check_neighbors( # Neighbor was explicitly set -> skip verification of neighbor devtype continue - neighbor_dev: Device = session.query(Device).filter(Device.hostname == neighbor).one_or_none() + neighbor_dev = session.query(Device).filter(Device.hostname == neighbor).one_or_none() if not neighbor_dev: raise NeighborError("Neighbor device {} not found in database".format(neighbor)) if devtype == DeviceType.CORE: @@ -293,7 +293,7 @@ def pre_init_check_neighbors( def pre_init_check_mlag(session, dev, mlag_peer_dev): - intfs: Interface = ( + intfs: List[Interface] = ( session.query(Interface) .filter(Interface.device == dev) .filter(InterfaceConfigType == InterfaceConfigType.MLAG_PEER) @@ -355,12 +355,12 @@ def init_mlag_peer_only( device_id: int, mlag_peer_id: int, mlag_peer_new_hostname: str, - job_id: Optional[str] = None, - scheduled_by: Optional[str] = None, + job_id: Optional[int] = None, + scheduled_by: str = "", ): """Try to initialize second MLAG switch if first succeeded but second failed""" logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none() if not dev: raise ValueError(f"No device with id {device_id} found") @@ -382,16 +382,18 @@ def cleanup_init_step1_result(nrresult: List[Union[Result, MultiResult]]) -> Lis for res in nrresult: # These tasks are supposed to get connection timeouts etc, setting them # to failed=False will keep job history clean and cause less confusion - if res.name in ["Push base management config", "push_base_management", "napalm_get"]: + if res.name in ["Push base management config", "push_base_management", "napalm_get"] and isinstance( + res, Result + ): res.failed = False res.result = "" if res.name == "ztp_device_cert" and not api_settings.VERIFY_TLS_DEVICE: - if type(res) is Result: + if isinstance(res, Result): res.failed = False - elif type(res) is MultiResult: + elif isinstance(res, MultiResult): for sres in res: - if type(sres) is Result: - sres.failed = False + if isinstance(res, Result): + sres.failed = False # bug multiresult.failed is read only return nrresult @@ -402,8 +404,8 @@ def init_access_device_step1( mlag_peer_id: Optional[int] = None, mlag_peer_new_hostname: Optional[str] = None, uplink_hostnames_arg: Optional[List[str]] = [], - job_id: Optional[str] = None, - scheduled_by: Optional[str] = None, + job_id: Optional[int] = None, + scheduled_by: str = "", ) -> NornirJobResult: """Initialize access device for management by CNaaS-NMS. If a MLAG/MC-LAG pair is to be configured both mlag_peer_id and @@ -427,7 +429,7 @@ def init_access_device_step1( ValueError """ logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore try: dev: Device = pre_init_checks(session, device_id) except DeviceStateError as e: @@ -440,7 +442,7 @@ def init_access_device_step1( # If this is the first device in an MLAG pair if mlag_peer_id and mlag_peer_new_hostname: - mlag_peer_dev: Device = pre_init_checks(session, mlag_peer_id) + mlag_peer_dev = pre_init_checks(session, mlag_peer_id) # update linknets using LLDP data linknets_all += update_linknets( session, dev.hostname, DeviceType.ACCESS, mlag_peer_dev=mlag_peer_dev, dry_run=True @@ -449,7 +451,7 @@ def init_access_device_step1( linknets_all += update_linknets( session, mlag_peer_dev.hostname, DeviceType.ACCESS, mlag_peer_dev=dev, dry_run=True ) - linknets = Linknet.deduplicate_linknet_dicts(linknets_all) + linknets: List[Any] = Linknet.deduplicate_linknet_dicts(linknets_all) update_interfacedb_worker( session, dev, @@ -542,15 +544,17 @@ def init_access_device_step1( session.commit() # Populate variables for template rendering - mgmt_gw_ipif = ip_interface(mgmtdomain.primary_gw) - mgmt_variables = { - "mgmt_ipif": str(ip_interface("{}/{}".format(mgmt_ip, mgmt_gw_ipif.network.prefixlen))), - "mgmt_ip": str(mgmt_ip), - "mgmt_prefixlen": int(mgmt_gw_ipif.network.prefixlen), - "mgmt_vlan_id": mgmtdomain.vlan, - "mgmt_gw": mgmt_gw_ipif.ip, - } - if secondary_mgmt_ip: + if mgmtdomain.primary_gw is not None: + mgmt_gw_ipif = ip_interface(mgmtdomain.primary_gw) + + mgmt_variables = { + "mgmt_ipif": str(ip_interface("{}/{}".format(mgmt_ip, mgmt_gw_ipif.network.prefixlen))), + "mgmt_ip": str(mgmt_ip), + "mgmt_prefixlen": int(mgmt_gw_ipif.network.prefixlen), + "mgmt_vlan_id": mgmtdomain.vlan, + "mgmt_gw": mgmt_gw_ipif.ip, + } + if secondary_mgmt_ip and mgmtdomain.secondary_gw: secondary_mgmt_gw_ipif = ip_interface(mgmtdomain.secondary_gw) mgmt_variables.update( { @@ -594,8 +598,8 @@ def init_access_device_step1( iter([res for res in nrresult[hostname] if res.name == "napalm_get"]), None ) - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.id == device_id).one() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.id == device_id).one() # If a get call to the old IP does not fail, it means management IP change did not work # Abort and rollback to initial state before device_init if not napalm_get_oldip_result or not napalm_get_oldip_result.failed: @@ -677,8 +681,8 @@ def init_fabric_device_step1( new_hostname: str, device_type: str, neighbors: Optional[List[str]] = [], - job_id: Optional[str] = None, - scheduled_by: Optional[str] = None, + job_id: Optional[int] = None, + scheduled_by: str = "", ) -> NornirJobResult: """Initialize fabric (CORE/DIST) device for management by CNaaS-NMS. @@ -706,7 +710,7 @@ def init_fabric_device_step1( if devtype not in [DeviceType.CORE, DeviceType.DIST]: raise ValueError("Init fabric device requires device type DIST or CORE") - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev = pre_init_checks(session, device_id) # Test update of linknets using LLDP data @@ -733,8 +737,12 @@ def init_fabric_device_step1( mgmt_ip = cnaas_nms.devicehandler.underlay.find_free_mgmt_lo_ip(session) infra_ip = cnaas_nms.devicehandler.underlay.find_free_infra_ip(session) + if mgmt_ip is None: + ip_version = 0 + else: + ip_version = mgmt_ip.version - reserved_ip = ReservedIP(device=dev, ip=mgmt_ip, ip_version=mgmt_ip.version) + reserved_ip = ReservedIP(device=dev, ip=mgmt_ip, ip_version=ip_version) session.add(reserved_ip) dev.infra_ip = infra_ip session.commit() @@ -773,7 +781,7 @@ def init_fabric_device_step1( task=push_base_management, device_variables=device_variables, devtype=devtype, job_id=job_id ) - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev = session.query(Device).filter(Device.id == device_id).one() dev.management_ip = mgmt_ip dev.state = DeviceState.INIT @@ -832,12 +840,12 @@ def schedule_init_device_step2(device_id: int, iteration: int, scheduled_by: str @job_wrapper def init_device_step2( - device_id: int, iteration: int = -1, job_id: Optional[str] = None, scheduled_by: Optional[str] = None + device_id: int, iteration: int = -1, job_id: Optional[int] = None, scheduled_by: str = "" ) -> NornirJobResult: logger = get_logger() # step4+ in apjob: if success, update management ip and device state, trigger external stuff? - with sqla_session() as session: - dev = session.query(Device).filter(Device.id == device_id).one() + with sqla_session() as session: # type: ignore + dev: Device = session.query(Device).filter(Device.id == device_id).one() if dev.state != DeviceState.INIT: logger.error( "Device with ID {} got to init step2 but is in incorrect state: {}".format(device_id, dev.state.name) @@ -864,15 +872,15 @@ def init_device_step2( if hostname != found_hostname: raise InitError("Newly initialized device presents wrong hostname") - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.id == device_id).one() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.id == device_id).one() dev.state = DeviceState.MANAGED dev.synchronized = False add_sync_event(hostname, "device_init", scheduled_by, job_id) set_facts(dev, facts) management_ip = dev.management_ip dev.dhcp_ip = None - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore # Plugin hook: new managed device # Send: hostname , device type , serial , platform , vendor , model , os version @@ -893,7 +901,7 @@ def init_device_step2( return NornirJobResult(nrresult=nrresult) -def schedule_discover_device(ztp_mac: str, dhcp_ip: str, iteration: int, scheduled_by: str) -> Optional[Job]: +def schedule_discover_device(ztp_mac: str, dhcp_ip: str, iteration: int, scheduled_by: str = "") -> Optional[Job]: max_iterations = 3 if 0 < iteration <= max_iterations: scheduler = Scheduler() @@ -910,14 +918,14 @@ def schedule_discover_device(ztp_mac: str, dhcp_ip: str, iteration: int, schedul def set_hostname_task(task, new_hostname: str): local_repo_path = app_settings.TEMPLATES_LOCAL - template_vars = {} # host is already set by nornir + # template_vars = {} # host is already set by nornir r = task.run( task=template_file, name="Generate hostname config", template="hostname.j2", jinja_env=get_jinja_env(f"{local_repo_path}/{task.host.platform}"), path=f"{local_repo_path}/{task.host.platform}", - **template_vars, + # **template_vars, ) task.host["config"] = r.result task.run( @@ -930,11 +938,9 @@ def set_hostname_task(task, new_hostname: str): @job_wrapper -def discover_device( - ztp_mac: str, dhcp_ip: str, iteration: int, job_id: Optional[str] = None, scheduled_by: Optional[str] = None -): +def discover_device(ztp_mac: str, dhcp_ip: str, iteration: int, job_id: Optional[int] = None, scheduled_by: str = ""): logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.ztp_mac == ztp_mac).one_or_none() if not dev: raise ValueError("Device with ztp_mac {} not found".format(ztp_mac)) @@ -958,14 +964,14 @@ def discover_device( return NornirJobResult(nrresult=nrresult) try: facts = nrresult[hostname][0].result["facts"] - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.ztp_mac == ztp_mac).one() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.ztp_mac == ztp_mac).one() dev.serial = facts["serial_number"][:64] dev.vendor = facts["vendor"][:64] dev.model = facts["model"][:64] dev.os_version = facts["os_version"][:64] dev.state = DeviceState.DISCOVERED - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore new_hostname = dev.hostname logger.info( f"Device with ztp_mac {ztp_mac} successfully scanned" diff --git a/src/cnaas_nms/devicehandler/interface_state.py b/src/cnaas_nms/devicehandler/interface_state.py index 05ce01e6..6d27f0ab 100644 --- a/src/cnaas_nms/devicehandler/interface_state.py +++ b/src/cnaas_nms/devicehandler/interface_state.py @@ -27,7 +27,7 @@ def get_interface_states(hostname) -> dict: def pre_bounce_check(hostname: str, interfaces: List[str]): # Check1: Database state - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: raise ValueError(f"Hostname {hostname} not found in database") diff --git a/src/cnaas_nms/devicehandler/nornir_helper.py b/src/cnaas_nms/devicehandler/nornir_helper.py index 9c802328..6a75cccb 100644 --- a/src/cnaas_nms/devicehandler/nornir_helper.py +++ b/src/cnaas_nms/devicehandler/nornir_helper.py @@ -1,7 +1,7 @@ import os from dataclasses import dataclass from functools import lru_cache -from typing import List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from jinja2 import Environment as JinjaEnvironment from jinja2 import FileSystemLoader @@ -58,7 +58,7 @@ def nr_result_serialize(result: AggregatedResult): if not isinstance(result, AggregatedResult): raise ValueError("result must be of type AggregatedResult") - hosts = {} + hosts: dict[str, Any] = {} for host, multires in result.items(): hosts[host] = {"failed": False, "job_tasks": []} for res in multires: @@ -90,7 +90,7 @@ def inventory_selector( Tuple with: filtered Nornir inventory, total device count selected, list of hostnames that was skipped because of resync=False """ - skipped_devices = [] + skipped_devices: List[str] = [] if hostname: if isinstance(hostname, str): nr_filtered = nr.filter(name=hostname).filter(managed=True) diff --git a/src/cnaas_nms/devicehandler/nornir_plugins/cnaas_inventory.py b/src/cnaas_nms/devicehandler/nornir_plugins/cnaas_inventory.py index c265c386..4f9a914b 100644 --- a/src/cnaas_nms/devicehandler/nornir_plugins/cnaas_inventory.py +++ b/src/cnaas_nms/devicehandler/nornir_plugins/cnaas_inventory.py @@ -74,7 +74,7 @@ def load(self) -> Inventory: groups[group_name] = Group(name=group_name, defaults=defaults) hosts = Hosts() - with cnaas_nms.db.session.sqla_session() as session: + with cnaas_nms.db.session.sqla_session() as session: # typing: ignore instance: Device for instance in session.query(Device): hostname = self._get_management_ip(instance.management_ip, instance.dhcp_ip) diff --git a/src/cnaas_nms/devicehandler/sync_devices.py b/src/cnaas_nms/devicehandler/sync_devices.py index 7a288337..59832eff 100644 --- a/src/cnaas_nms/devicehandler/sync_devices.py +++ b/src/cnaas_nms/devicehandler/sync_devices.py @@ -3,7 +3,7 @@ import time from hashlib import sha256 from ipaddress import IPv4Address, IPv4Interface, ip_interface -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple import yaml from napalm.eos import EOSDriver as NapalmEOSDriver @@ -79,6 +79,7 @@ def resolve_vlanid(vlan_name: str, vxlans: dict) -> Optional[int]: except (KeyError, ValueError) as e: logger.error("Could not resolve VLAN ID for VLAN name {}: {}".format(vlan_name, str(e))) return None + return None def resolve_vlanid_list(vlan_name_list: List[str], vxlans: dict) -> List[int]: @@ -93,8 +94,8 @@ def resolve_vlanid_list(vlan_name_list: List[str], vxlans: dict) -> List[int]: def get_mlag_vars(session, dev: Device) -> dict: - ret = {"mlag_peer": False, "mlag_peer_hostname": None, "mlag_peer_low": None} - mlag_peer: Device = dev.get_mlag_peer(session) + ret: dict[str, Any] = {"mlag_peer": False, "mlag_peer_hostname": None, "mlag_peer_low": None} + mlag_peer: Device | None = dev.get_mlag_peer(session) if not mlag_peer: return ret ret["mlag_peer"] = True @@ -110,12 +111,12 @@ def populate_device_vars( session, dev: Device, ztp_hostname: Optional[str] = None, ztp_devtype: Optional[DeviceType] = None ): logger = get_logger() - device_variables = { + device_variables: dict[str, Any] = { "device_model": dev.model, "device_os_version": dev.os_version, "device_id": dev.id, "hostname": dev.hostname, - "stack_members": [] + "stack_members": [], # 'host' variable is also implicitly added by nornir-jinja2 } @@ -128,12 +129,12 @@ def populate_device_vars( if ztp_hostname: hostname: str = ztp_hostname else: - hostname: str = dev.hostname + hostname = dev.hostname if ztp_devtype: devtype: DeviceType = ztp_devtype elif dev.device_type != DeviceType.UNKNOWN: - devtype: DeviceType = dev.device_type + devtype = dev.device_type else: raise Exception("Can't populate device vars for device type UNKNOWN") @@ -151,10 +152,10 @@ def populate_device_vars( if devtype == DeviceType.ACCESS: if ztp_hostname: - access_device_variables = {"interfaces": []} + access_device_variables: dict[str, Any] = {"interfaces": []} else: - mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain_by_ip(session, dev.management_ip) - if not mgmtdomain: + mgmtdomain = cnaas_nms.db.helper.find_mgmtdomain_by_ip(session, IPv4Address(dev.management_ip)) + if not mgmtdomain or not mgmtdomain.primary_gw: raise Exception( "Could not find appropriate management domain for management_ip: {}".format(dev.management_ip) ) @@ -168,7 +169,7 @@ def populate_device_vars( "mgmt_prefixlen": int(mgmt_gw_ipif.network.prefixlen), "interfaces": [], } - if dev.secondary_management_ip: + if dev.secondary_management_ip and mgmtdomain.secondary_gw: secondary_mgmt_gw_ipif = ip_interface(mgmtdomain.secondary_gw) access_device_variables.update( { @@ -187,31 +188,31 @@ def populate_device_vars( ifname_peer_map = dev.get_linknet_localif_mapping(session) intfs = session.query(Interface).filter(Interface.device == dev).all() - intf: Interface - for intf in intfs: + interface: Interface + for interface in intfs: untagged_vlan: Optional[int] = None tagged_vlan_list: List = [] intfdata: Optional[dict] = None try: - ifindexnum: int = Interface.interface_index_num(intf.name) + ifindexnum: int = Interface.interface_index_num(interface.name) except ValueError: - ifindexnum: int = 0 - if intf.data: - if "untagged_vlan" in intf.data: - untagged_vlan = resolve_vlanid(intf.data["untagged_vlan"], settings["vxlans"]) - if "tagged_vlan_list" in intf.data: - tagged_vlan_list = resolve_vlanid_list(intf.data["tagged_vlan_list"], settings["vxlans"]) - intfdata = dict(intf.data) - if intf.name in ifname_peer_map: + ifindexnum = 0 + if interface.data: + if "untagged_vlan" in interface.data: + untagged_vlan = resolve_vlanid(interface.data["untagged_vlan"], settings["vxlans"]) + if "tagged_vlan_list" in interface.data: + tagged_vlan_list = resolve_vlanid_list(interface.data["tagged_vlan_list"], settings["vxlans"]) + intfdata = dict(interface.data) + if interface.name in ifname_peer_map: if isinstance(intfdata, dict): - intfdata["description"] = ifname_peer_map[intf.name] + intfdata["description"] = ifname_peer_map[interface.name] else: - intfdata = {"description": ifname_peer_map[intf.name]} + intfdata = {"description": ifname_peer_map[interface.name]} access_device_variables["interfaces"].append( { - "name": intf.name, - "ifclass": intf.configtype.name, + "name": interface.name, + "ifclass": interface.configtype.name, "untagged_vlan": untagged_vlan, "tagged_vlan_list": tagged_vlan_list, "data": intfdata, @@ -222,8 +223,8 @@ def populate_device_vars( device_variables = {**device_variables, **access_device_variables, **mlag_vars} elif devtype == DeviceType.DIST or devtype == DeviceType.CORE: infra_ip = dev.infra_ip - asn = generate_asn(infra_ip) - fabric_device_variables = { + asn = generate_asn(IPv4Address(infra_ip)) + fabric_device_variables: dict[str, Any] = { "interfaces": [], "bgp_ipv4_peers": [], "bgp_evpn_peers": [], @@ -254,23 +255,23 @@ def populate_device_vars( "peer_hostname": neighbor_d.hostname, "peer_infra_lo": str(neighbor_d.infra_ip), "peer_ip": str(neighbor_ip), - "peer_asn": generate_asn(neighbor_d.infra_ip), + "peer_asn": generate_asn(IPv4Address(neighbor_d.infra_ip)), } fabric_device_variables["bgp_ipv4_peers"].append( { "peer_hostname": neighbor_d.hostname, "peer_infra_lo": str(neighbor_d.infra_ip), "peer_ip": str(neighbor_ip), - "peer_asn": generate_asn(neighbor_d.infra_ip), + "peer_asn": generate_asn(IPv4Address(neighbor_d.infra_ip)), } ) ifname_peer_map = dev.get_linknet_localif_mapping(session) if "interfaces" in settings and settings["interfaces"]: for intf in expand_interface_settings(settings["interfaces"]): try: - ifindexnum: int = Interface.interface_index_num(intf["name"]) + ifindexnum = Interface.interface_index_num(intf["name"]) except ValueError: - ifindexnum: int = 0 + ifindexnum = 0 if "ifclass" not in intf: continue extra_keys = ["aggregate_id", "enabled", "cli_append_str", "metric", "mtu", "tags"] @@ -342,9 +343,11 @@ def populate_device_vars( "vlan": mgmtdom.vlan, "description": mgmtdom.description, "esi_mac": mgmtdom.esi_mac, - "ipv4_ip": str(mgmtdom.device_a_ip) - if hostname == mgmtdom.device_a.hostname - else str(mgmtdom.device_b_ip), + "ipv4_ip": ( + str(mgmtdom.device_a_ip) + if hostname == mgmtdom.device_a.hostname + else str(mgmtdom.device_b_ip) + ), } ) # populate evpn peers data @@ -355,7 +358,7 @@ def populate_device_vars( { "peer_hostname": neighbor_d.hostname, "peer_infra_lo": str(neighbor_d.infra_ip), - "peer_asn": generate_asn(neighbor_d.infra_ip), + "peer_asn": generate_asn(IPv4Address(neighbor_d.infra_ip)), } ) device_variables = {**device_variables, **fabric_device_variables} @@ -482,7 +485,7 @@ def napalm_confirm_commit(task, job_id: int, prev_job_id: int): n_device.confirm_commit() logger.debug("Commit for job {} confirmed on device {}".format(prev_job_id, task.host.name)) if job_id: - with redis_session() as db: + with redis_session() as db: # type: ignore db.lpush("finished_devices_" + str(job_id), task.host.name) @@ -491,8 +494,8 @@ def push_sync_device( confirm_mode: int, dry_run: bool = True, generate_only: bool = False, - job_id: Optional[str] = None, - scheduled_by: Optional[str] = None, + job_id: Optional[int] = None, + scheduled_by: str = "", ): """ Nornir task to generate config and push to device @@ -511,7 +514,7 @@ def push_sync_device( set_thread_data(job_id) logger = get_logger() hostname = task.host.name - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one() template_vars = populate_device_vars(session, dev) platform = dev.platform @@ -610,11 +613,11 @@ def push_sync_device( else: task.host["change_score"] = 0 if job_id: - with redis_session() as db: + with redis_session() as db: # type: ignore db.lpush("finished_devices_" + str(job_id), task.host.name) -def generate_only(hostname: str) -> (str, dict): +def generate_only(hostname: str) -> Tuple[str, dict]: """ Generate configuration for a device and return it as a text string. @@ -662,7 +665,7 @@ def sync_check_hash(task, force=False, job_id=None): set_thread_data(job_id) if force is True: return - with sqla_session() as session: + with sqla_session() as session: # type: ignore stored_hash = Device.get_config_hash(session, task.host.name) if stored_hash is None: return @@ -698,7 +701,7 @@ def update_config_hash(task): logger.exception("Unable to get config hash: {}".format(str(e))) raise e else: - with sqla_session() as session: + with sqla_session() as session: # type: ignore Device.set_config_hash(session, task.host.name, new_config_hash) logger.debug("Config hash for {} updated to {}".format(task.host.name, new_config_hash)) @@ -765,13 +768,13 @@ def confirm_devices( prev_job_id: int, hostnames: List[str], job_id: Optional[int] = None, - scheduled_by: Optional[str] = None, + scheduled_by: str = "", resync: bool = False, ) -> NornirJobResult: logger = get_logger() nr = cnaas_init() - nr_filtered, dev_count, skipped_hostnames = select_devices(nr, hostnames, resync) + nr_filtered, dev_count, skipped_hostnames = select_devices(nr, hostnames, resync=resync) device_list = list(nr_filtered.inventory.hosts.keys()) logger.info("Device(s) selected for commit-confirm ({}): {}".format(dev_count, ", ".join(device_list))) @@ -781,7 +784,7 @@ def confirm_devices( except Exception as e: logger.exception("Exception while confirm-commit devices: {}".format(str(e))) try: - with sqla_session() as session: + with sqla_session() as session: # type: ignore logger.info( "Releasing lock for devices from syncto job: {} (in commit-job {})".format(prev_job_id, job_id) ) @@ -802,7 +805,7 @@ def confirm_devices( dry_run=False, force=False, nr_filtered=nr_filtered, unchanged_hosts=[], failed_hosts=failed_hosts ) - with sqla_session() as session: + with sqla_session() as session: # type: ignore for host, results in nrresult.items(): if host in failed_hosts or len(results) != 1: logger.debug("Setting device as unsync for failed commit-confirm on device {}".format(host)) @@ -811,10 +814,10 @@ def confirm_devices( add_sync_event(host, "commit_confirm_failed", scheduled_by, job_id) dev.confhash = None else: - dev: Device = session.query(Device).filter(Device.hostname == host).one() + dev = session.query(Device).filter(Device.hostname == host).one() dev.synchronized = True remove_sync_events(host) - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore logger.info("Releasing lock for devices from syncto job: {} (in commit-job {})".format(prev_job_id, job_id)) Joblock.release_lock(session, job_id=prev_job_id) @@ -831,7 +834,7 @@ def sync_devices( force: bool = False, auto_push: bool = False, job_id: Optional[int] = None, - scheduled_by: Optional[str] = None, + scheduled_by: str = "", resync: bool = False, confirm_mode_override: Optional[int] = None, ) -> NornirJobResult: @@ -872,18 +875,20 @@ def sync_devices( else: if nrresult.failed: # Mark devices as unsynchronized if config hash check failed - with sqla_session() as session: + with sqla_session() as session: # type: ignore session.query(Device).filter(Device.hostname.in_(nrresult.failed_hosts.keys())).update( {Device.synchronized: False}, synchronize_session=False ) raise Exception("Configuration hash check failed for {}".format(" ".join(nrresult.failed_hosts.keys()))) if not dry_run: - with sqla_session() as session: + with sqla_session() as session: # type: ignore logger.info("Trying to acquire lock for devices to run syncto job: {}".format(job_id)) max_attempts = 5 lock_ok: bool = False for i in range(max_attempts): + if not job_id: + continue lock_ok = Joblock.acquire_lock(session, name="devices", job_id=job_id) if lock_ok: break @@ -904,7 +909,7 @@ def sync_devices( logger.exception("Exception while synchronizing devices: {}".format(str(e))) try: if not dry_run: - with sqla_session() as session: + with sqla_session() as session: # type: ignore logger.info("Releasing lock for devices from syncto job: {}".format(job_id)) Joblock.release_lock(session, job_id=job_id) except Exception: @@ -946,25 +951,25 @@ def sync_devices( ) # set devices as synchronized if needed - with sqla_session() as session: + with sqla_session() as session: # type: ignore for hostname in changed_hosts: if dry_run: dev: Device = session.query(Device).filter(Device.hostname == hostname).one() if dev.synchronized: dev.synchronized = False add_sync_event(hostname, "syncto_dryrun", scheduled_by, job_id) - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore # if next job will commit, that job will mark synchronized on success elif get_confirm_mode(confirm_mode_override) != 2: - dev: Device = session.query(Device).filter(Device.hostname == hostname).one() + dev = session.query(Device).filter(Device.hostname == hostname).one() dev.synchronized = True remove_sync_events(hostname) - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore for hostname in unchanged_hosts: - dev: Device = session.query(Device).filter(Device.hostname == hostname).one() + dev = session.query(Device).filter(Device.hostname == hostname).one() dev.synchronized = True remove_sync_events(hostname) - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore if not dry_run and get_confirm_mode(confirm_mode_override) != 2: if failed_hosts and get_confirm_mode(confirm_mode_override) == 1: logger.error( @@ -1030,15 +1035,13 @@ def sync_devices( logger.info(f"Commit-confirm for job id {job_id} scheduled as job id {next_job_id}") # keep this thread running until next_job has finished so the device session is not closed, # causing cancellation of pending commits - with sqla_session() as session: + with sqla_session() as session: # type: ignore Job.wait_for_job_completion(session, next_job_id) return NornirJobResult(nrresult=nrresult, next_job_id=next_job_id, change_score=total_change_score) -def push_static_config( - task, config: str, dry_run: bool = True, job_id: Optional[str] = None, scheduled_by: Optional[str] = None -): +def push_static_config(task, config: str, dry_run: bool = True, job_id: Optional[int] = None, scheduled_by: str = ""): """ Nornir task to push static config to device @@ -1060,7 +1063,7 @@ def push_static_config( @job_wrapper def apply_config( - hostname: str, config: str, dry_run: bool, job_id: Optional[int] = None, scheduled_by: Optional[str] = None + hostname: str, config: str, dry_run: bool, job_id: Optional[int] = None, scheduled_by: str = "" ) -> NornirJobResult: """Apply a static configuration (from backup etc) to a device. @@ -1076,7 +1079,7 @@ def apply_config( """ logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: raise Exception("Device {} not found".format(hostname)) @@ -1092,8 +1095,8 @@ def apply_config( logger.exception("Exception in apply_config: {}".format(e)) else: if not dry_run: - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.hostname == hostname).one_or_none() dev.state = DeviceState.UNMANAGED dev.synchronized = False add_sync_event(hostname, "apply_config", scheduled_by, job_id) diff --git a/src/cnaas_nms/devicehandler/sync_history.py b/src/cnaas_nms/devicehandler/sync_history.py index 3ac9202b..a878af3b 100644 --- a/src/cnaas_nms/devicehandler/sync_history.py +++ b/src/cnaas_nms/devicehandler/sync_history.py @@ -55,7 +55,7 @@ def add_sync_event( if not timestamp: timestamp = time.time() sync_event = SyncEvent(cause, timestamp, by, job_id) - with redis_session() as redis: + with redis_session() as redis: # type: ignore if not redis.exists(REDIS_SYNC_HISTORY_KEYNAME): new_history = SyncHistory(history={hostname: [sync_event]}) redis.hset(REDIS_SYNC_HISTORY_KEYNAME, mapping=new_history.redis_dump()) @@ -83,7 +83,7 @@ def add_sync_event( def get_sync_events(hostnames: Optional[List[str]] = None) -> SyncHistory: ret = SyncHistory(history={}) sync_history = SyncHistory(history={}) - with redis_session() as redis: + with redis_session() as redis: # type: ignore sync_history.redis_load(redis.hgetall(REDIS_SYNC_HISTORY_KEYNAME)) if hostnames: for hostname, events in sync_history.history.items(): @@ -96,5 +96,5 @@ def get_sync_events(hostnames: Optional[List[str]] = None) -> SyncHistory: def remove_sync_events(hostname: str): - with redis_session() as redis: + with redis_session() as redis: # type: ignore redis.hdel(REDIS_SYNC_HISTORY_KEYNAME, hostname) diff --git a/src/cnaas_nms/devicehandler/tests/test_get.py b/src/cnaas_nms/devicehandler/tests/test_get.py index 9b7d103e..3a02baef 100644 --- a/src/cnaas_nms/devicehandler/tests/test_get.py +++ b/src/cnaas_nms/devicehandler/tests/test_get.py @@ -44,7 +44,7 @@ def test_get_inventory(self): self.assertLessEqual(1, len(result["hosts"].items())) def test_get_mlag_ifs(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore try: dev_a: Device = self.create_test_device(self.testdata["mlag_dev_a"]) dev_b: Device = self.create_test_device(self.testdata["mlag_dev_b"]) @@ -83,7 +83,7 @@ def test_get_mlag_ifs(self): @pytest.mark.equipment def test_update_links(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore new_links = cnaas_nms.devicehandler.update.update_linknets( session, self.testdata["init_access_new_hostname"], DeviceType.ACCESS ) @@ -91,7 +91,7 @@ def test_update_links(self): @pytest.mark.equipment def test_get_running_config_interface(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore if_config: str = cnaas_nms.devicehandler.get.get_running_config_interface(session, "eosdist1", "Ethernet1") assert if_config.strip(), "no config found" diff --git a/src/cnaas_nms/devicehandler/tests/test_init.py b/src/cnaas_nms/devicehandler/tests/test_init.py index ff524cf2..a1ad1891 100644 --- a/src/cnaas_nms/devicehandler/tests/test_init.py +++ b/src/cnaas_nms/devicehandler/tests/test_init.py @@ -31,7 +31,7 @@ def tearDown(self): time.sleep(1) for i in range(1, 11): num_scheduled_jobs = len(ap_scheduler.get_jobs()) - with sqla_session() as session: + with sqla_session() as session: # type: ignore num_running_jobs = session.query(Job).count() print( "Number of jobs scheduled: {}, number of jobs running: {}".format(num_scheduled_jobs, num_running_jobs) @@ -83,7 +83,7 @@ def reset_access_device(self): reset_interfacedb(self.testdata["init_access_new_hostname"]) - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = ( session.query(Device).filter(Device.hostname == self.testdata["init_access_new_hostname"]).one() ) diff --git a/src/cnaas_nms/devicehandler/tests/test_syncto.py b/src/cnaas_nms/devicehandler/tests/test_syncto.py index 0c2c388a..811be19a 100644 --- a/src/cnaas_nms/devicehandler/tests/test_syncto.py +++ b/src/cnaas_nms/devicehandler/tests/test_syncto.py @@ -36,7 +36,7 @@ def run_syncto_job(scheduler, testdata: dict, dry_run: bool = True) -> Optional[ job_res: Optional[Job] = None job_dict: Optional[dict] = None jobstatus_wait = [JobStatus.SCHEDULED, JobStatus.RUNNING] - with sqla_session() as session: + with sqla_session() as session: # type: ignore for i in range(1, 60): time.sleep(1) if not job_res or job_res.status in jobstatus_wait: diff --git a/src/cnaas_nms/devicehandler/tests/test_update.py b/src/cnaas_nms/devicehandler/tests/test_update.py index 257a9e60..085a6fa1 100644 --- a/src/cnaas_nms/devicehandler/tests/test_update.py +++ b/src/cnaas_nms/devicehandler/tests/test_update.py @@ -38,7 +38,7 @@ def get_linknets(self, session, neighbors_data: Optional[dict] = None, hostname: ) def test_update_linknet_eosaccess(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore linknets = self.get_linknets(session) for ln in linknets: ln["device_a_id"] = None @@ -50,7 +50,7 @@ def test_update_linknet_eosaccess(self): ) def test_update_linknet_eosaccess_nonredundant(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore linknets = self.get_linknets(session, self.testdata["lldp_data_nonredundant"]) for ln in linknets: ln["device_a_id"] = None @@ -62,7 +62,7 @@ def test_update_linknet_eosaccess_nonredundant(self): ) def test_update_linknet_wrong_porttype(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore neighbors_data = { "Ethernet2": [{"hostname": "eosdist1", "port": "Ethernet1"}], "Ethernet3": [{"hostname": "eosdist2", "port": "Ethernet1"}], @@ -79,7 +79,7 @@ def test_update_linknet_wrong_porttype(self): ) def test_pre_init_check_access_redundant(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore linknets = self.get_linknets(session) dev: Device = session.query(Device).filter(Device.hostname == "eosaccess").one() self.assertListEqual( @@ -87,13 +87,13 @@ def test_pre_init_check_access_redundant(self): ) def test_pre_init_check_access_nonredundant(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore linknets = self.get_linknets(session, self.testdata["lldp_data_nonredundant"]) dev: Device = session.query(Device).filter(Device.hostname == "eosaccess").one() self.assertListEqual(pre_init_check_neighbors(session, dev, DeviceType.ACCESS, linknets), ["eosdist1"]) def test_pre_init_check_access_nonredundant_error(self): - with sqla_session() as session: + with sqla_session() as session: # type: ignore linknets = self.get_linknets(session, self.testdata["lldp_data_nonredundant_error"]) dev: Device = session.query(Device).filter(Device.hostname == "eosaccess").one() self.assertRaises( diff --git a/src/cnaas_nms/devicehandler/update.py b/src/cnaas_nms/devicehandler/update.py index f227d83d..9c4c53d4 100644 --- a/src/cnaas_nms/devicehandler/update.py +++ b/src/cnaas_nms/devicehandler/update.py @@ -31,14 +31,14 @@ def update_interfacedb_worker( replace: bool, delete_all: bool, mlag_peer_hostname: Optional[str] = None, - linknets: Optional[List[dict]] = None, + linknets: List[dict] = [], ) -> List[dict]: """Perform actual work of updating database for update_interfacedb. If replace is set to true, configtype and data will get overwritten. If delete_all is set to true, delete all interfaces from database. Return list of new/updated interfaces, or empty if delete_all was set.""" logger = get_logger() - ret = [] + ret: List[dict] = [] current_iflist = session.query(Interface).filter(Interface.device == dev).all() unmatched_iflist = [] @@ -73,7 +73,7 @@ def update_interfacedb_worker( new_intf = False else: new_intf = True - intf: Interface = Interface() + intf = Interface() if not new_intf and not replace: continue logger.debug("New/updated physical interface found on device {}: {}".format(dev.hostname, intf_name)) @@ -115,8 +115,8 @@ def update_interfacedb( replace: bool = False, delete_all: bool = False, mlag_peer_hostname: Optional[str] = None, - job_id: Optional[str] = None, - scheduled_by: Optional[str] = None, + job_id: Optional[int] = None, + scheduled_by: str = "", ) -> DictJobResult: """Update interface DB with any new physical interfaces for specified device. If replace is set, any existing records in the database will get overwritten. @@ -125,7 +125,7 @@ def update_interfacedb( Returns: List of interfaces that was added to DB """ - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: raise ValueError(f"Hostname {hostname} not found in database") @@ -143,7 +143,7 @@ def update_interfacedb( def reset_interfacedb(hostname: str): - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: raise ValueError(f"Hostname {hostname} not found in database") @@ -173,9 +173,9 @@ def set_facts(dev: Device, facts: dict) -> dict: @job_wrapper -def update_facts(hostname: str, job_id: Optional[str] = None, scheduled_by: Optional[str] = None): +def update_facts(hostname: str, job_id: Optional[int] = None, scheduled_by: str = ""): logger = get_logger() - with sqla_session() as session: + with sqla_session() as session: # type: ignore dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none() if not dev: raise ValueError("Device with hostname {} not found".format(hostname)) @@ -193,10 +193,10 @@ def update_facts(hostname: str, job_id: Optional[str] = None, scheduled_by: Opti return NornirJobResult(nrresult=nrresult) try: facts = nrresult[hostname][0].result["facts"] - with sqla_session() as session: - dev: Device = session.query(Device).filter(Device.hostname == hostname).one() + with sqla_session() as session: # type: ignore + dev = session.query(Device).filter(Device.hostname == hostname).one() diff = set_facts(dev, facts) - dev.last_seen = datetime.datetime.utcnow() + dev.last_seen = datetime.datetime.utcnow() # type: ignore logger.debug( "Updating facts for device {}, new values: {}, {}, {}, {}".format( diff --git a/src/cnaas_nms/models/permissions.py b/src/cnaas_nms/models/permissions.py index 1aaace07..27dafbbe 100644 --- a/src/cnaas_nms/models/permissions.py +++ b/src/cnaas_nms/models/permissions.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Dict, Optional from pydantic import BaseModel, model_validator @@ -20,7 +20,7 @@ class RoleModel(BaseModel): class PermissionsModel(BaseModel): config: Optional[PemissionConfig] = None - group_mappings: Optional[Dict[str, Dict[str, list[str]]]] = {} + group_mappings: Optional[Dict[str, Any]] = {} roles: Dict[str, RoleModel] @model_validator(mode="after") @@ -32,6 +32,8 @@ def check_if_default_permissions_role_exist(self) -> "PermissionsModel": @model_validator(mode="after") def check_if_roles_in_mappings_exist(self) -> "PermissionsModel": + if self.group_mappings is None: + return self for map_type in self.group_mappings: for group in self.group_mappings[map_type]: for role_name in self.group_mappings[map_type][group]: diff --git a/src/cnaas_nms/plugins/nav.py b/src/cnaas_nms/plugins/nav.py index 87173d19..cf7d67ca 100644 --- a/src/cnaas_nms/plugins/nav.py +++ b/src/cnaas_nms/plugins/nav.py @@ -7,7 +7,7 @@ class Plugin(CnaasBasePlugin): - def __init__(self): + def __init__(self) -> None: self.urlbase = None self.apitoken = None self.organizationid = "Undefined" @@ -26,7 +26,7 @@ def __init__(self): self.snmp_community = pluginvars["snmp_community"] @hookimpl - def selftest(self): + def selftest(self) -> bool: if self.urlbase and self.apitoken: return True else: @@ -34,7 +34,7 @@ def selftest(self): @hookimpl def new_managed_device(self, hostname, device_type, serial_number, vendor, model, os_version, management_ip): - headers = {"Authorization": "Token " + self.apitoken} + headers = {"Authorization": "Token " + str(self.apitoken)} data = { "ip": management_ip, "sysname": hostname, @@ -44,7 +44,7 @@ def new_managed_device(self, hostname, device_type, serial_number, vendor, model "snmp_version": 2, "read_only": self.snmp_community, } - r = requests.post(self.urlbase + "/api/1/netbox/", headers=headers, json=data) + r = requests.post(str(self.urlbase) + "/api/1/netbox/", headers=headers, json=data) if not r.status_code == 201: logger.warn("Failed to add device to NAV: code {}: {} (data: {})".format(r.status_code, r.text, data)) return False diff --git a/src/cnaas_nms/plugins/ni.py b/src/cnaas_nms/plugins/ni.py index 12c03d80..71b4108c 100644 --- a/src/cnaas_nms/plugins/ni.py +++ b/src/cnaas_nms/plugins/ni.py @@ -1,3 +1,5 @@ +from typing import Any, List + import requests from cnaas_nms.plugins.pluginspec import CnaasBasePlugin, hookimpl @@ -7,7 +9,7 @@ class Plugin(CnaasBasePlugin): - def __init__(self): + def __init__(self) -> None: self.urlbase = None self.apiuser = None self.apitoken = None @@ -22,19 +24,19 @@ def __init__(self): self.apitoken = pluginvars["apitoken"] @hookimpl - def selftest(self): + def selftest(self) -> bool: if self.urlbase and self.apiuser and self.apitoken: return True else: return False @hookimpl - def new_managed_device(self, hostname, device_type, serial_number, vendor, model, os_version, management_ip): + def new_managed_device(self, hostname, device_type, serial_number, vendor, model, os_version, management_ip): # type: ignore headers = {"Authorization": "ApiKey {}:{}".format(self.apiuser, self.apitoken)} - data = {"node": {"operational_state": "In service"}} + data: dict[str, dict[str, Any]] = {"node": {"operational_state": "In service"}} - res = requests.get(self.urlbase, headers=headers, verify=False) + res = requests.get(str(self.urlbase), headers=headers, verify=False) if not res.status_code == 200: logger.warn("Failed to fetch devices from NI: {}: {} ({})".format(res.status_code, res.text, data)) @@ -46,14 +48,14 @@ def new_managed_device(self, hostname, device_type, serial_number, vendor, model if management_ip: if "ip_addresses" in device["node"]: - addresses = device["node"]["ip_addresses"] + addresses: List[str] = device["node"]["ip_addresses"] + addresses.insert(0, management_ip) data["node"]["ip_addresses"] = addresses - data["node"]["ip_addresses"].insert(0, management_ip) else: data["node"]["ip_addresses"] = [management_ip] handle_id = device["handle_id"] - res = requests.put(self.urlbase + str(handle_id) + "/", headers=headers, json=data, verify=False) + res = requests.put(str(self.urlbase) + str(handle_id) + "/", headers=headers, json=data, verify=False) if res.status_code != 204: logger.warn("Could not change device {} with ID {}.".format(hostname, handle_id)) diff --git a/src/cnaas_nms/run.py b/src/cnaas_nms/run.py index a429d796..ef993247 100644 --- a/src/cnaas_nms/run.py +++ b/src/cnaas_nms/run.py @@ -7,7 +7,6 @@ import coverage from gevent import monkey from gevent import signal as gevent_signal -from redis import StrictRedis from cnaas_nms.app_settings import api_settings @@ -62,13 +61,13 @@ def get_app(): pmh.load_plugins() try: - with sqla_session() as session: + with sqla_session() as session: # type: ignore Joblock.clear_locks(session) except Exception as e: print("Unable to clear old locks from database at startup: {}".format(str(e))) # noqa: T001 try: - with sqla_session() as session: + with sqla_session() as session: # type: ignore Job.clear_jobs(session) except Exception as e: print("Unable to clear jobs with invalid states: {}".format(str(e))) # noqa: T001 @@ -93,6 +92,8 @@ def loglevel_to_rooms(levelname: str) -> List[str]: return ["DEBUG", "INFO", "WARNING", "ERROR"] elif levelname == "CRITICAL": return ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] + else: + raise Exception("Invalid levelname given") def parse_redis_event(event): @@ -117,8 +118,7 @@ def emit_redis_event(event): def thread_websocket_events(): - redis: StrictRedis - with redis_session() as redis: + with redis_session() as redis: # type: ignore last_event = b"$" while True: result = redis.xread({"events": last_event}, count=10, block=200) diff --git a/src/cnaas_nms/scheduler/scheduler.py b/src/cnaas_nms/scheduler/scheduler.py index ac4e5cbe..ba05d64f 100644 --- a/src/cnaas_nms/scheduler/scheduler.py +++ b/src/cnaas_nms/scheduler/scheduler.py @@ -143,12 +143,12 @@ def remove_scheduled_job(self, job_id, abort_message="removed"): else: self.remove_local_job(job_id) - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() job.finish_abort(message=abort_message) def add_onetime_job( - self, func: Union[str, FunctionType], when: Optional[int] = None, scheduled_by: Optional[str] = None, **kwargs + self, func: Union[str, FunctionType], when: Optional[int] = None, scheduled_by: str = "", **kwargs ) -> int: """Schedule a job to run at a later time on the mule worker or local scheduler depending on setup. @@ -191,10 +191,10 @@ def add_onetime_job( except Exception: pass - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = Job() if run_date: - job.scheduled_time = run_date + job.scheduled_time = run_date # type: ignore job.function_name = func_name if scheduled_by is None: scheduled_by = "unknown" diff --git a/src/cnaas_nms/scheduler/tests/test_scheduler.py b/src/cnaas_nms/scheduler/tests/test_scheduler.py index 0b9b78bd..40ab9e98 100644 --- a/src/cnaas_nms/scheduler/tests/test_scheduler.py +++ b/src/cnaas_nms/scheduler/tests/test_scheduler.py @@ -33,7 +33,7 @@ def test_add_schedule(postgresql, scheduler): print(f"Test job 1 scheduled as ID { job1_id }") print(f"Test job 2 scheduled as ID { job2_id }") time.sleep(3) - with sqla_session() as session: + with sqla_session() as session: # type: ignore job1 = session.query(Job).filter(Job.id == job1_id).one_or_none() assert isinstance(job1, Job), "Test job 1 could not be found" assert job1.status == JobStatus.FINISHED, "Test job 1 did not finish" @@ -53,7 +53,7 @@ def test_abort_schedule(postgresql, scheduler): print(f"Test job 3 scheduled as ID { job3_id }") scheduler.remove_scheduled_job(job3_id) time.sleep(3) - with sqla_session() as session: + with sqla_session() as session: # type: ignore job3 = session.query(Job).filter(Job.id == job3_id).one_or_none() assert isinstance(job3, Job), "Test job 3 could not be found" assert job3.status == JobStatus.ABORTED, "Test job 3 did not abort" diff --git a/src/cnaas_nms/scheduler/wrapper.py b/src/cnaas_nms/scheduler/wrapper.py index 75948b7f..73dd0861 100644 --- a/src/cnaas_nms/scheduler/wrapper.py +++ b/src/cnaas_nms/scheduler/wrapper.py @@ -25,13 +25,13 @@ def insert_job_id(result: JobResult, job_id: int) -> JobResult: def update_device_progress(job_id: int): new_finished_devices = [] - with redis_session() as db: - while db.llen("finished_devices_" + str(job_id)) != 0: - last_finished = db.lpop("finished_devices_" + str(job_id)) + with redis_session() as redis: # type: ignore + while redis.llen("finished_devices_" + str(job_id)) != 0: + last_finished = redis.lpop("finished_devices_" + str(job_id)) new_finished_devices.append(last_finished) if new_finished_devices: - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() if not job: raise ValueError("Could not find Job with ID {}".format(job_id)) @@ -54,7 +54,7 @@ def wrapper(job_id: int, scheduled_by: str, kwargs={}): logger.error(errmsg) raise ValueError(errmsg) progress_funcitons = ["sync_devices", "device_upgrade", "confirm_devices"] - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() if not job: errmsg = "Could not find job_id {} in database".format(job_id) @@ -80,7 +80,7 @@ def wrapper(job_id: int, scheduled_by: str, kwargs={}): except Exception as e: tb = traceback.format_exc() logger.debug("Exception traceback in job_wrapper: {}".format(tb)) - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() if not job: errmsg = "Could not find job_id {} in database".format(job_id) @@ -95,7 +95,7 @@ def wrapper(job_id: int, scheduled_by: str, kwargs={}): if func.__name__ in progress_funcitons: stop_event.set() device_thread.join() - with sqla_session() as session: + with sqla_session() as session: # type: ignore job = session.query(Job).filter(Job.id == job_id).one_or_none() if not job: errmsg = "Could not find job_id {} in database".format(job_id) diff --git a/src/cnaas_nms/scheduler_mule.py b/src/cnaas_nms/scheduler_mule.py index fe33b32a..19ca4f6d 100644 --- a/src/cnaas_nms/scheduler_mule.py +++ b/src/cnaas_nms/scheduler_mule.py @@ -26,17 +26,17 @@ def is_coverage_enabled(): cov = coverage.coverage(data_file=".coverage-{}".format(os.getpid())) cov.start() - def save_coverage(): + def save_coverage() -> None: cov.stop() cov.save() - def save_coverage_signal(): + def save_coverage_signal() -> None: cov.stop() cov.save() atexit.register(save_coverage) - signal.signal(signal.SIGTERM, save_coverage_signal) - signal.signal(signal.SIGINT, save_coverage_signal) + signal.signal(signal.SIGTERM, save_coverage_signal) # type: ignore + signal.signal(signal.SIGINT, save_coverage_signal) # type: ignore def pre_schedule_checks(scheduler, kwargs): @@ -53,14 +53,14 @@ def pre_schedule_checks(scheduler, kwargs): if not check_ok: logger.debug(message) - with sqla_session() as session: + with sqla_session() as session: # type: ignore job_entry: Job = session.query(Job).filter(Job.id == kwargs["job_id"]).one_or_none() job_entry.finish_abort(message) return check_ok -def main_loop(): +def main_loop() -> None: try: import uwsgi except Exception as e: @@ -76,7 +76,7 @@ def main_loop(): pmh.load_plugins() try: - with sqla_session() as session: + with sqla_session() as session: # type: ignore Joblock.clear_locks(session) except Exception as e: logger.exception("Unable to clear old locks from database at startup: {}".format(str(e))) diff --git a/src/cnaas_nms/tools/cache.py b/src/cnaas_nms/tools/cache.py index c30bddf9..84f55af4 100644 --- a/src/cnaas_nms/tools/cache.py +++ b/src/cnaas_nms/tools/cache.py @@ -17,7 +17,7 @@ def get_token_info_from_cache(token: Token) -> Optional[dict]: """Check if the userinfo is in the cache to avoid multiple calls to the OIDC server""" try: - with redis_session() as redis: + with redis_session() as redis: # type: ignore cached_token_info = redis.hget(REDIS_OAUTH_TOKEN_INFO_KEY, token.decoded_token["sub"]) if cached_token_info: return json.loads(cached_token_info) @@ -31,7 +31,7 @@ def get_token_info_from_cache(token: Token) -> Optional[dict]: def put_token_info_in_cache(token: Token, token_info) -> bool: """Put the userinfo in the cache to avoid multiple calls to the OIDC server""" try: - with redis_session() as redis: + with redis_session() as redis: # type: ignore if "exp" in token.decoded_token: redis.hsetnx(REDIS_OAUTH_TOKEN_INFO_KEY, token.decoded_token["sub"], token_info) # expire hash at access_token expiry time or 1 hour from now (whichever is sooner) diff --git a/src/cnaas_nms/tools/dhcp_hook.py b/src/cnaas_nms/tools/dhcp_hook.py index a4418d9e..b01213d4 100644 --- a/src/cnaas_nms/tools/dhcp_hook.py +++ b/src/cnaas_nms/tools/dhcp_hook.py @@ -3,6 +3,7 @@ import logging import os import sys +from typing import Union import netaddr import requests @@ -54,7 +55,7 @@ def canonical_mac(mac): return str(na_mac) -def main() -> int: +def main() -> Union[int, None]: if len(sys.argv) < 3: return 1 @@ -175,6 +176,7 @@ def main() -> int: r_data["data"]["added_device"]["hostname"], r_data["data"]["added_device"]["id"] ) ) + return None if __name__ == "__main__": diff --git a/src/cnaas_nms/tools/dropdb.py b/src/cnaas_nms/tools/dropdb.py index 75c4308a..30337b80 100644 --- a/src/cnaas_nms/tools/dropdb.py +++ b/src/cnaas_nms/tools/dropdb.py @@ -24,4 +24,4 @@ sys.exit(0) -print(Base.metadata.drop_all(engine)) +Base.metadata.drop_all(engine) diff --git a/src/cnaas_nms/tools/event.py b/src/cnaas_nms/tools/event.py index 09b2b176..d8f91b16 100644 --- a/src/cnaas_nms/tools/event.py +++ b/src/cnaas_nms/tools/event.py @@ -22,16 +22,16 @@ def add_event( Returns: """ - with redis_session() as redis: + with redis_session() as redis: # type: ignore try: send_data = {"type": event_type, "level": level} if event_type == "log": - send_data["message"] = message + send_data["message"] = str(message) elif event_type == "update": - send_data["update_type"] = update_type - send_data["json"] = json_data + send_data["update_type"] = str(update_type) + send_data["json"] = str(json_data) elif event_type == "sync": - send_data["json"] = json_data + send_data["json"] = str(json_data) redis.xadd("events", send_data, maxlen=100) except Exception as e: print("Error in add_event: {}".format(e)) diff --git a/src/cnaas_nms/tools/initdb.py b/src/cnaas_nms/tools/initdb.py index 30b8ab32..9eb9f6f6 100644 --- a/src/cnaas_nms/tools/initdb.py +++ b/src/cnaas_nms/tools/initdb.py @@ -28,7 +28,7 @@ print(Mgmtdomain.__table__) print(Interface.__table__) -print(Base.metadata.create_all(engine)) +Base.metadata.create_all(engine) t = Site() t.description = "default" diff --git a/src/cnaas_nms/tools/jinja_filters.py b/src/cnaas_nms/tools/jinja_filters.py index b51cf098..dd577208 100644 --- a/src/cnaas_nms/tools/jinja_filters.py +++ b/src/cnaas_nms/tools/jinja_filters.py @@ -104,7 +104,9 @@ def isofy_ipv4(ip_string, prefix=""): @template_filter() -def ipv4_to_ipv6(v6_network: Union[str, ipaddress.IPv6Network], v4_address: Union[str, ipaddress.IPv4Interface]): +def ipv4_to_ipv6( + v6_network: str | ipaddress.IPv6Network, v4_address: str | ipaddress.IPv4Interface | ipaddress.IPv4Address +): """Transforms an IPv4 address to an IPv6 interface address. This will combine an arbitrary IPv6 network address with the 32 address bytes of an IPv4 address into a valid IPv6 address + prefix length notation - the equivalent of dotted quad compatible notation. @@ -133,8 +135,9 @@ def ipv4_to_ipv6(v6_network: Union[str, ipaddress.IPv6Network], v4_address: Unio @template_filter() def get_interface( - network: Union[ipaddress.IPv6Interface, ipaddress.IPv4Interface, str], index: int -) -> Union[ipaddress.IPv6Interface, ipaddress.IPv4Interface]: + network: Union[ipaddress.IPv6Interface, ipaddress.IPv4Interface, ipaddress.IPv6Network, ipaddress.IPv4Network, str], + index: int, +) -> Union[ipaddress.IPv6Interface, ipaddress.IPv4Interface, ipaddress.IPv6Network, ipaddress.IPv4Network]: """Returns a host address with a prefix length from its index in a network. Example: @@ -148,8 +151,8 @@ def get_interface( if isinstance(network, str): network = ipaddress.ip_network(network) - host = network[index] - return ipaddress.ip_interface(f"{host}/{network.prefixlen}") + host = network[index] # type: ignore + return ipaddress.ip_interface(f"{host}/{network.prefixlen}") # type: ignore @template_filter() diff --git a/src/cnaas_nms/tools/jinja_helpers.py b/src/cnaas_nms/tools/jinja_helpers.py index 098caa24..fa3b5da6 100644 --- a/src/cnaas_nms/tools/jinja_helpers.py +++ b/src/cnaas_nms/tools/jinja_helpers.py @@ -1,10 +1,12 @@ """Functions that aid in the building of Jinja template contexts""" + import os +from typing import Any def get_environment_secrets(prefix="TEMPLATE_SECRET_"): """Returns a dictionary of secrets stored in environment variables""" - template_secrets = {env: value for env, value in os.environ.items() if env.startswith(prefix)} + template_secrets: dict[str, Any] = {env: value for env, value in os.environ.items() if env.startswith(prefix)} # Also make secrets available as a dict, so keys can be constructed dynamically in templates template_secrets["TEMPLATE_SECRET"] = {env.replace(prefix, ""): value for env, value in template_secrets.items()} diff --git a/src/cnaas_nms/tools/log.py b/src/cnaas_nms/tools/log.py index 8c443f4e..efd2095b 100644 --- a/src/cnaas_nms/tools/log.py +++ b/src/cnaas_nms/tools/log.py @@ -23,13 +23,13 @@ def get_logger(): "[%(asctime)s] %(levelname)s in %(module)s job #{}: %(message)s".format(thread_data.job_id) ) # stdout logging - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) + stdout_handler = logging.StreamHandler() + stdout_handler.setFormatter(formatter) + logger.addHandler(stdout_handler) # websocket logging - handler = WebsocketHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) + websocket_handler = WebsocketHandler() + websocket_handler.setFormatter(formatter) + logger.addHandler(websocket_handler) elif current_app: logger = current_app.logger else: @@ -37,12 +37,12 @@ def get_logger(): if not logger.handlers: formatter = logging.Formatter("[%(asctime)s] %(levelname)s in %(module)s: %(message)s") # stdout logging - handler = logging.StreamHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) + stdout_handler = logging.StreamHandler() + stdout_handler.setFormatter(formatter) + logger.addHandler(stdout_handler) # websocket logging - handler = WebsocketHandler() - handler.setFormatter(formatter) - logger.addHandler(handler) + websocket_handler = WebsocketHandler() + websocket_handler.setFormatter(formatter) + logger.addHandler(websocket_handler) logger.setLevel(logging.DEBUG) # TODO: get from /etc config ? return logger diff --git a/src/cnaas_nms/tools/oidc/key_management.py b/src/cnaas_nms/tools/oidc/key_management.py index 216b02db..a6369044 100644 --- a/src/cnaas_nms/tools/oidc/key_management.py +++ b/src/cnaas_nms/tools/oidc/key_management.py @@ -1,4 +1,4 @@ -from typing import Any, Mapping, Optional +from typing import Any, List, Mapping import requests from jwt.exceptions import InvalidKeyError @@ -11,13 +11,10 @@ class JWKSStore(object, metaclass=SingletonType): - keys: Mapping[str, Any] + keys: List[Mapping[str, Any]] - def __init__(self, keys: Optional[Mapping[str, Any]] = None): - if keys: - self.keys = keys - else: - self.keys = {} + def __init__(self, keys: List[Mapping[str, Any]] = []): + self.keys = keys def get_keys(): @@ -35,7 +32,7 @@ def get_keys(): raise ConnectionError("Can't retrieve keys") -def get_key(kid): +def get_key(kid: str): """Get the key based on the kid""" jwks_store = JWKSStore() key = [k for k in jwks_store.keys if k["kid"] == kid] diff --git a/src/cnaas_nms/tools/oidc/oidc_client_call.py b/src/cnaas_nms/tools/oidc/oidc_client_call.py index cd44978b..7dcd2ec1 100644 --- a/src/cnaas_nms/tools/oidc/oidc_client_call.py +++ b/src/cnaas_nms/tools/oidc/oidc_client_call.py @@ -1,5 +1,4 @@ import json -from typing import Optional import requests from jwt.exceptions import ExpiredSignatureError, InvalidTokenError @@ -76,7 +75,7 @@ def get_token_info_from_introspect(session: requests.Session, token: Token, intr raise InvalidTokenError("Invalid JSON in introspection response: {}".format(str(e))) -def get_oauth_token_info(token: Token) -> Optional[dict]: +def get_oauth_token_info(token: Token) -> dict: """Give back the details about the token from userinfo or introspection If OIDC is disabled, we return None. @@ -88,12 +87,7 @@ def get_oauth_token_info(token: Token) -> Optional[dict]: resp.json(): Object of the user info or introspection """ - # For now unnecessary, useful when we only use one log in method - if not auth_settings.OIDC_ENABLED: - return None - # Get the cached token info - cached_token_info = get_token_info_from_cache(token) if cached_token_info: return cached_token_info diff --git a/src/cnaas_nms/tools/oidc/token.py b/src/cnaas_nms/tools/oidc/token.py index 6d64109f..648a2ab9 100644 --- a/src/cnaas_nms/tools/oidc/token.py +++ b/src/cnaas_nms/tools/oidc/token.py @@ -1,7 +1,6 @@ class Token: token_string: str = "" - decoded_token = {} - expires_at = "" + decoded_token: dict = {} def __init__(self, token_string: str, decoded_token: dict): self.token_string = token_string diff --git a/src/cnaas_nms/tools/pki.py b/src/cnaas_nms/tools/pki.py index 369488ba..2fd44abb 100644 --- a/src/cnaas_nms/tools/pki.py +++ b/src/cnaas_nms/tools/pki.py @@ -37,13 +37,13 @@ def get_ssl_context(): def generate_device_cert(hostname: str, ipv4_address: IPv4Address): try: - if not os.path.isfile(api_settings.CAFILE): + if not os.path.isfile(str(api_settings.CAFILE)): raise Exception("Specified cafile is not a file: {}".format(api_settings.CAFILE)) except KeyError: raise Exception("No cafile specified in api.yml") try: - if not os.path.isfile(api_settings.CAKEYFILE): + if not os.path.isfile(str(api_settings.CAKEYFILE)): raise Exception("Specified cakeyfile is not a file: {}".format(api_settings.CAKEYFILE)) except KeyError: raise Exception("No cakeyfile specified in api.yml") @@ -54,13 +54,13 @@ def generate_device_cert(hostname: str, ipv4_address: IPv4Address): except KeyError: raise Exception("No certpath found in API settings") - with open(api_settings.CAKEYFILE, "rb") as cakeyfile: + with open(str(api_settings.CAKEYFILE), "rb") as cakeyfile: root_key = serialization.load_pem_private_key( cakeyfile.read(), password=None, ) - with open(api_settings.CAFILE, "rb") as cafile: + with open(str(api_settings.CAFILE), "rb") as cafile: root_cert = x509.load_pem_x509_certificate(cafile.read()) cert_key = rsa.generate_private_key(public_exponent=65537, key_size=2048, backend=default_backend()) @@ -81,7 +81,7 @@ def generate_device_cert(hostname: str, ipv4_address: IPv4Address): x509.SubjectAlternativeName([x509.IPAddress(ipv4_address)]), critical=False, ) - .sign(root_key, hashes.SHA256(), default_backend()) + .sign(root_key, hashes.SHA256(), default_backend()) # type: ignore ) with open(os.path.join(api_settings.CERTPATH, "{}.crt".format(hostname)), "wb") as f: diff --git a/src/cnaas_nms/tools/rbac/rbac.py b/src/cnaas_nms/tools/rbac/rbac.py index 8b7333ae..ea35343b 100644 --- a/src/cnaas_nms/tools/rbac/rbac.py +++ b/src/cnaas_nms/tools/rbac/rbac.py @@ -9,7 +9,7 @@ def get_permissions_user(permissions_rules: PermissionsModel, user_info: dict): """Get the API permissions of the user""" - permissions_of_user = [] + permissions_of_user: list[PermissionModel] = [] # if no rules, return if not permissions_rules: @@ -17,9 +17,9 @@ def get_permissions_user(permissions_rules: PermissionsModel, user_info: dict): # first give all the permissions of the fallback role if permissions_rules.config and permissions_rules.config.default_permissions: - permissions_of_user.extend( - permissions_rules.roles.get(permissions_rules.config.default_permissions).permissions - ) + default_role = permissions_rules.roles.get(permissions_rules.config.default_permissions) + if default_role is not None: + permissions_of_user.extend(default_role.permissions) user_roles: List[str] = [] # read the group mappings and add the relevant roles @@ -50,10 +50,13 @@ def check_if_api_call_is_permitted(request: FlaskJsonRequest, permissions_of_use allowed_methods = permission.methods allowed_endpoints = permission.endpoints + # check if any endpoints or methods allowed + if allowed_endpoints is None or allowed_methods is None: + continue + # check if allowed based on the method if "*" not in allowed_methods and request.method not in allowed_methods: continue - # prepare the uri prefix = "/api/{}".format(__api_version__) short_uri = request.uri.split(prefix, 1)[1].split("?", 1)[0] diff --git a/src/cnaas_nms/tools/security.py b/src/cnaas_nms/tools/security.py index e4aa65c6..ace66cf0 100644 --- a/src/cnaas_nms/tools/security.py +++ b/src/cnaas_nms/tools/security.py @@ -30,7 +30,7 @@ def get_jwt_identity(): class MyBearerTokenValidator(BearerTokenValidator): - def authenticate_token(self, token_string: str) -> Token: + def authenticate_token(self, token_string: str) -> str | Token: """Check if token is active. If JWT is disabled, we return because no token is needed. @@ -60,9 +60,8 @@ def authenticate_token(self, token_string: str) -> Token: raise InvalidTokenError(e) except exceptions.JWTError: # check if we can still authenticate the user with user info - token = Token(token_string, None) - get_oauth_token_info(token) - return token + token = Token(token_string, {}) + return Token(token_string, get_oauth_token_info(token)) # get the key key = get_key(unverified_header.get("kid")) diff --git a/src/cnaas_nms/tools/testsetup.py b/src/cnaas_nms/tools/testsetup.py index 77740c0b..988ec073 100644 --- a/src/cnaas_nms/tools/testsetup.py +++ b/src/cnaas_nms/tools/testsetup.py @@ -74,7 +74,7 @@ def __init__(self, user="cnaas", passwd="cnaas", database="cnaas"): else: self.shutdown() logging.debug("Failed to start postgres") - assert (False, "Could not start postgres") + assert False, "Could not start postgres" # Copy the database dump to the container. subprocess.call(