Skip to content

Commit

Permalink
add mypy to pre-commit and solved mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
Josephine.Rutten committed Oct 28, 2024
1 parent 4916854 commit 2d69f8b
Show file tree
Hide file tree
Showing 72 changed files with 713 additions and 639 deletions.
18 changes: 18 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,21 @@ repos:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix, --show-fixes ]
exclude: (alembic/.*)
- repo: local
hooks:
- id: mypy
name: mypy
language: system
entry: "mypy"
types: [ python ]
require_serial: true
verbose: true
args:
- --no-warn-incomplete-stub
- --ignore-missing-imports
- --no-warn-unused-ignores
- --allow-untyped-decorators
- --no-namespace-packages
- --allow-untyped-globals
- --check-untyped-defs
exclude: "(alembic/.*)|(.*test.*)"
2 changes: 1 addition & 1 deletion src/cnaas_nms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ def setup_package():
from cnaas_nms.api.tests.app_wrapper import TestAppWrapper

app = cnaas_nms.api.app.app
app.wsgi_app = TestAppWrapper(app.wsgi_app, None)
app.wsgi_app = TestAppWrapper(app.wsgi_app, None) # type: ignore
client = app.test_client()
data = {"action": "refresh"}
client.put("/api/v1.0/repository/settings", json=data)
Expand Down
11 changes: 6 additions & 5 deletions src/cnaas_nms/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ def socketio_on_connect():
if auth_settings.OIDC_ENABLED:
try:
token = oauth_required.get_token_validator("bearer").authenticate_token(token_string)
user = get_oauth_token_info(token)[auth_settings.OIDC_USERNAME_ATTRIBUTE]
token_info = get_oauth_token_info(token)
user = token_info[auth_settings.OIDC_USERNAME_ATTRIBUTE]
except InvalidTokenError as e:
logger.debug("InvalidTokenError: " + format(e))
return False
Expand Down Expand Up @@ -229,21 +230,21 @@ def log_request(response):
elif request.method in ["GET", "POST", "PUT", "DELETE", "PATCH"]:
try:
if auth_settings.OIDC_ENABLED:
token_string = request.headers.get("Authorization").split(" ")[-1]
token_string = str(request.headers.get("Authorization")).split(" ")[-1]
token = oauth_required.get_token_validator("bearer").authenticate_token(token_string)
token_info = get_oauth_token_info(token)
if auth_settings.OIDC_USERNAME_ATTRIBUTE in token_info:
if token_info is not None and auth_settings.OIDC_USERNAME_ATTRIBUTE in token_info:
user = "User: {} ({}), ".format(
get_oauth_token_info(token)[auth_settings.OIDC_USERNAME_ATTRIBUTE],
auth_settings.OIDC_USERNAME_ATTRIBUTE,
)
elif "client_id" in token_info:
elif token_info is not None and "client_id" in token_info:
user = "User: {} (client_id), ".format(get_oauth_token_info(token)["client_id"])
else:
logger.warning("Could not get user info from token")
raise ValueError
else:
token_string = request.headers.get("Authorization").split(" ")[-1]
token_string = str(request.headers.get("Authorization")).split(" ")[-1]
user = "User: {}, ".format(decode_token(token_string).get("sub"))
except Exception:
user = "User: unknown, "
Expand Down
6 changes: 3 additions & 3 deletions src/cnaas_nms/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get(self):

req = PreparedRequest()
req.prepare_url(url, parameters)
resp = redirect(req.url, code=302)
resp = redirect(str(req.url), code=302)
if "refresh_token" in token:
resp.set_cookie(
"REFRESH_TOKEN",
Expand All @@ -106,7 +106,7 @@ class RefreshApi(Resource):
def post(self):
oauth_client = current_app.extensions["authlib.integrations.flask_client"]
oauth_client_connext: FlaskOAuth2App = oauth_client.connext
token_string = request.headers.get("Authorization").split(" ")[-1]
token_string = str(request.headers.get("Authorization")).split(" ")[-1]
oauth_client_connext.token = token_string
oauth_client_connext.load_server_metadata()
url = oauth_client_connext.server_metadata["token_endpoint"]
Expand Down Expand Up @@ -155,7 +155,7 @@ def get(self):
logger.debug("No permissions defined, so nobody is permitted to do any api calls.")
return []
user_info = get_oauth_token_info(current_token)
permissions_of_user = get_permissions_user(permissions_rules, user_info)
permissions_of_user = get_permissions_user(permissions_rules, user_info) # check check

# convert to dictionaries so it can be converted to json
permissions_as_dics = []
Expand Down
73 changes: 38 additions & 35 deletions src/cnaas_nms/api/device.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime
import json
from typing import List, Optional
from typing import Any, List, Optional

from flask import make_response, request
from flask_restx import Namespace, Resource, fields, marshal
Expand Down Expand Up @@ -242,7 +242,7 @@ def get(self, device_id):
"""Get a device from ID"""
result = empty_result()
result["data"] = {"devices": []}
with sqla_session() as session:
with sqla_session() as session: # type: ignore
instance = session.query(Device).filter(Device.id == device_id).one_or_none()
if instance:
result["data"]["devices"] = device_data_postprocess([instance])
Expand Down Expand Up @@ -273,7 +273,8 @@ def delete(self, device_id):
return res
elif not isinstance(json_data["factory_default"], bool):
return empty_result(status="error", data="Argument factory_default must be boolean"), 400
with sqla_session() as session:

with sqla_session() as session: # type: ignore
dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none()
if not dev:
return empty_result("error", "Device not found"), 404
Expand Down Expand Up @@ -305,7 +306,7 @@ def delete(self, device_id):
def put(self, device_id):
"""Modify device from ID"""
json_data = request.get_json()
with sqla_session() as session:
with sqla_session() as session: # type: ignore
dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none()
dev_prev_state: DeviceState = dev.state

Expand Down Expand Up @@ -350,7 +351,7 @@ def get(self, hostname):
"""Get a device from hostname"""
result = empty_result()
result["data"] = {"devices": []}
with sqla_session() as session:
with sqla_session() as session: # type: ignore
instance = session.query(Device).filter(Device.hostname == hostname).one_or_none()
if instance:
result["data"]["devices"] = device_data_postprocess([instance])
Expand All @@ -371,7 +372,7 @@ def post(self):
data, errors = Device.validate(**json_data)
if errors != []:
return empty_result(status="error", data=errors), 400
with sqla_session() as session:
with sqla_session() as session: # type: ignore
instance: Device = session.query(Device).filter(Device.hostname == data["hostname"]).one_or_none()
if instance:
errors.append("Device already exists")
Expand Down Expand Up @@ -403,7 +404,7 @@ def get(self):
logger.info("started get devices")
device_list: List[Device] = []
total_count = 0
with sqla_session() as session:
with sqla_session() as session: # type: ignore
query = session.query(Device, func.count(Device.id).over().label("total"))
try:
query = build_filter(Device, query)
Expand All @@ -416,7 +417,7 @@ def get(self):

resp = make_response(json.dumps(empty_result(status="success", data=data)), 200)
resp.headers["Content-Type"] = "application/json"
resp.headers = {**resp.headers, **pagination_headers(total_count)}
resp.headers = {**resp.headers, **pagination_headers(total_count)} # type: ignore
return resp


Expand All @@ -433,7 +434,7 @@ def post(self, device_id: int):

# If device init is already in progress, reschedule a new step2 (connectivity check)
# instead of trying to restart initialization
with sqla_session() as session:
with sqla_session() as session: # type: ignore
dev: Device = session.query(Device).filter(Device.id == device_id).one_or_none()
if (
dev
Expand Down Expand Up @@ -475,14 +476,14 @@ def post(self, device_id: int):
else:
return empty_result(status="error", data="Unsupported 'device_type' provided"), 400

res = empty_result(data=f"Scheduled job to initialize device_id { device_id }")
res = empty_result(data=f"Scheduled job to initialize device_id { str(device_id) }")
res["job_id"] = job_id

return res

@classmethod
def arg_check(cls, device_id: int, json_data: dict) -> dict:
parsed_args = {"device_id": device_id}
parsed_args: dict[str, Any] = {"device_id": device_id}
if not isinstance(device_id, int):
raise ValueError("'device_id' must be an integer")

Expand Down Expand Up @@ -540,22 +541,23 @@ class DeviceInitCheckApi(Resource):
def post(self, device_id: int):
"""Perform init check on a device"""
json_data = request.get_json()
ret = {}
ret: dict[str, Any] = {}
linknets_all = []
mlag_peer_dev: Optional[Device]
try:
parsed_args = DeviceInitApi.arg_check(device_id, json_data)
target_devtype = DeviceType[parsed_args["device_type"]]
target_hostname = parsed_args["new_hostname"]
mlag_peer_target_hostname: Optional[str] = None
mlag_peer_id: Optional[int] = None
mlag_peer_dev: Optional[Device] = None
mlag_peer_dev = None
if "mlag_peer_id" in parsed_args and "mlag_peer_new_hostname" in parsed_args:
mlag_peer_target_hostname = parsed_args["mlag_peer_new_hostname"]
mlag_peer_id = parsed_args["mlag_peer_id"]
except ValueError as e:
return empty_result(status="error", data="Error parsing arguments: {}".format(e)), 400

with sqla_session() as session:
with sqla_session() as session: # type: ignore
try:
dev: Device = cnaas_nms.devicehandler.init_device.pre_init_checks(session, device_id)
linknets_all = dev.get_linknets_as_dict(session)
Expand All @@ -566,7 +568,7 @@ def post(self, device_id: int):

if mlag_peer_id:
try:
mlag_peer_dev: Device = cnaas_nms.devicehandler.init_device.pre_init_checks(session, mlag_peer_id)
mlag_peer_dev = cnaas_nms.devicehandler.init_device.pre_init_checks(session, mlag_peer_id)
linknets_all += mlag_peer_dev.get_linknets_as_dict(session)
except ValueError as e:
return empty_result(status="error", data="ValueError in pre_init_checks: {}".format(e)), 400
Expand Down Expand Up @@ -743,7 +745,7 @@ def post(self):

resp = make_response(json.dumps(res), 200)
if total_count:
resp.headers["X-Total-Count"] = total_count
resp.headers["X-Total-Count"] = str(total_count)
resp.headers["Content-Type"] = "application/json"
return resp

Expand Down Expand Up @@ -787,7 +789,7 @@ def post(self, hostname: str):

resp = make_response(json.dumps(res), 200)
if total_count:
resp.headers["X-Total-Count"] = total_count
resp.headers["X-Total-Count"] = str(total_count)
resp.headers["Content-Type"] = "application/json"
return resp

Expand All @@ -806,7 +808,7 @@ def post(self):
hostname = str(json_data["hostname"])
if not Device.valid_hostname(hostname):
return empty_result(status="error", data=f"Hostname '{hostname}' is not a valid hostname"), 400
with sqla_session() as session:
with sqla_session() as session: # type: ignore
dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none()
if not dev or (dev.state != DeviceState.MANAGED and dev.state != DeviceState.UNMANAGED):
return (
Expand All @@ -828,7 +830,7 @@ def post(self):

resp = make_response(json.dumps(res), 200)
if total_count:
resp.headers["X-Total-Count"] = total_count
resp.headers["X-Total-Count"] = str(total_count)
resp.headers["Content-Type"] = "application/json"
return resp

Expand All @@ -847,7 +849,7 @@ def post(self):
hostname = str(json_data["hostname"])
if not Device.valid_hostname(hostname):
return empty_result(status="error", data=f"Hostname '{hostname}' is not a valid hostname"), 400
with sqla_session() as session:
with sqla_session() as session: # type: ignore
dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none()
if not dev or (dev.state != DeviceState.MANAGED and dev.state != DeviceState.UNMANAGED):
return (
Expand All @@ -873,8 +875,8 @@ def post(self):
empty_result(status="error", data=f"Hostname '{mlag_peer_hostname}' is not a valid hostname"),
400,
)
with sqla_session() as session:
dev: Device = session.query(Device).filter(Device.hostname == mlag_peer_hostname).one_or_none()
with sqla_session() as session: # type: ignore
dev = session.query(Device).filter(Device.hostname == mlag_peer_hostname).one_or_none()
if not dev or (dev.state != DeviceState.MANAGED and dev.state != DeviceState.UNMANAGED):
return (
empty_result(
Expand Down Expand Up @@ -907,7 +909,7 @@ def post(self):

resp = make_response(json.dumps(res), 200)
if total_count:
resp.headers["X-Total-Count"] = total_count
resp.headers["X-Total-Count"] = str(total_count)
resp.headers["Content-Type"] = "application/json"
return resp

Expand Down Expand Up @@ -956,7 +958,7 @@ def get(self, hostname: str):
if not Device.valid_hostname(hostname):
return empty_result(status="error", data="Invalid hostname specified"), 400

with sqla_session() as session:
with sqla_session() as session: # type: ignore
dev: Device = session.query(Device).filter(Device.hostname == hostname).one_or_none()
if not dev:
return empty_result("error", "Device not found"), 404
Expand Down Expand Up @@ -987,7 +989,7 @@ def get(self, hostname: str):
if not Device.valid_hostname(hostname):
return empty_result(status="error", data="Invalid hostname specified"), 400

kwargs = {}
kwargs: dict[str, Any] = {}
if "job_id" in args:
try:
kwargs["job_id"] = int(args["job_id"])
Expand All @@ -1004,7 +1006,7 @@ def get(self, hostname: str):
except Exception:
return empty_result("error", "before must be a valid ISO format date time string"), 400

with sqla_session() as session:
with sqla_session() as session: # type: ignore
try:
result["data"] = Job.get_previous_config(session, hostname, **kwargs)
except JobNotFoundError as e:
Expand All @@ -1021,7 +1023,7 @@ def get(self, hostname: str):
def post(self, hostname: str):
"""Restore configuration to previous version"""
json_data = request.get_json()
apply_kwargs = {"hostname": hostname}
apply_kwargs: dict[str, Any] = {"hostname": hostname}
config = None
if not Device.valid_hostname(hostname):
return empty_result(status="error", data="Invalid hostname specified"), 400
Expand All @@ -1034,7 +1036,7 @@ def post(self, hostname: str):
else:
return empty_result("error", "job_id must be specified"), 400

with sqla_session() as session:
with sqla_session() as session: # type: ignore
try:
prev_config_result = Job.get_previous_config(session, hostname, job_id=job_id)
failed = prev_config_result["failed"]
Expand Down Expand Up @@ -1080,7 +1082,7 @@ class DeviceApplyConfigApi(Resource):
def post(self, hostname: str):
"""Apply exact specified configuration to device without using templates"""
json_data = request.get_json()
apply_kwargs = {"hostname": hostname}
apply_kwargs: dict[str, Any] = {"hostname": hostname}
allow_live_run = api_settings.ALLOW_APPLY_CONFIG_LIVERUN
if not Device.valid_hostname(hostname):
return empty_result(status="error", data="Invalid hostname specified"), 400
Expand Down Expand Up @@ -1165,7 +1167,7 @@ def post(self):

resp = make_response(json.dumps(res), 200)
if total_count:
resp.headers["X-Total-Count"] = total_count
resp.headers["X-Total-Count"] = str(total_count)
resp.headers["Content-Type"] = "application/json"
return resp
else:
Expand All @@ -1177,7 +1179,7 @@ class DeviceStackmembersApi(Resource):
def get(self, hostname):
"""Get stackmembers for device"""
result = empty_result(data={"stackmembers": []})
with sqla_session() as session:
with sqla_session() as session: # type: ignore
device = session.query(Device).filter(Device.hostname == hostname).one_or_none()
if not device:
return empty_result("error", "Device not found"), 404
Expand All @@ -1196,7 +1198,7 @@ def put(self, hostname):
errors = DeviceStackmembersApi.format_errors(e.errors())
return empty_result("error", errors), 400
result = empty_result(data={"stackmembers": []})
with sqla_session() as session:
with sqla_session() as session: # type: ignore
device_instance = session.query(Device).filter(Device.hostname == hostname).one_or_none()
if not device_instance:
return empty_result("error", "Device not found"), 404
Expand Down Expand Up @@ -1232,13 +1234,14 @@ def get(self):
args = request.args
result = empty_result()
result["data"] = {"hostnames": {}}
sync_history: SyncHistory

if "hostname" in args:
if not Device.valid_hostname(args["hostname"]):
return empty_result(status="error", data="Invalid hostname specified"), 400
sync_history: SyncHistory = get_sync_events([args["hostname"]])
sync_history = get_sync_events([args["hostname"]])
else:
sync_history: SyncHistory = get_sync_events()
sync_history = get_sync_events()

result["data"]["hostnames"] = sync_history.asdict()
return result
Expand All @@ -1250,7 +1253,7 @@ def post(self):
validated_json_data = NewSyncEventModel(**request.get_json()).model_dump()
except ValidationError as e:
return empty_result("error", parse_pydantic_error(e, NewSyncEventModel, request.get_json())), 400
with sqla_session() as session:
with sqla_session() as session: # type: ignore
device_instance = (
session.query(Device).filter(Device.hostname == validated_json_data["hostname"]).one_or_none()
)
Expand Down
Loading

0 comments on commit 2d69f8b

Please sign in to comment.