Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: address linting suggestions #127

Merged
merged 9 commits on Apr 11, 2024
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,8 @@ cython_debug/
.DS_Store

.vscode/
.idea/
.idea/

# Notebook Model Downloads
notebooks/PyTorchModels/
pytorch-model-scan-results.json
1 change: 1 addition & 0 deletions modelscan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""CLI for scanning models"""

import logging

from modelscan._version import __version__
Expand Down
17 changes: 9 additions & 8 deletions modelscan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import os
from pathlib import Path
from typing import Optional, Dict, Any
from typing import Optional
from tomlkit import parse

import click
Expand Down Expand Up @@ -87,7 +87,7 @@ def cli() -> None:
help="Optional file name for output report",
)
@cli.command(
help="[Default] Scan a model file or diretory for ability to execute suspicious actions. "
help="[Default] Scan a model file or directory for ability to execute suspicious actions. "
) # type: ignore
@click.pass_context
def scan(
Expand All @@ -112,7 +112,7 @@ def scan(
settings = DEFAULT_SETTINGS

if settings_file_path and settings_file_path.is_file():
with open(settings_file_path) as sf:
with open(settings_file_path, encoding="utf-8") as sf:
settings = parse(sf.read()).unwrap()
click.echo(f"Detected settings file. Using {settings_file_path}. \n")
else:
Expand All @@ -132,7 +132,7 @@ def scan(
raise click.UsageError("Command line must include a path")

# Report scan results
if reporting_format is not "custom":
if reporting_format != "custom":
modelscan._settings["reporting"]["module"] = DEFAULT_REPORTING_MODULES[
reporting_format
]
Expand Down Expand Up @@ -174,16 +174,17 @@ def create_settings(force: bool, location: Optional[str]) -> None:
settings_path = location

try:
open(settings_path)
open(settings_path, encoding="utf-8")
if force:
with open(settings_path, "w") as settings_file:
with open(settings_path, mode="w", encoding="utf-8") as settings_file:
settings_file.write(SettingsUtils.get_default_settings_as_toml())
else:
logger.warning(
f"{settings_path} file already exists. Please use `--force` flag if you intend to overwrite it."
"%s file already exists. Please use `--force` flag if you intend to overwrite it.",
settings_path,
)
except FileNotFoundError:
with open(settings_path, "w") as settings_file:
with open(settings_path, mode="w", encoding="utf-8") as settings_file:
settings_file.write(SettingsUtils.get_default_settings_as_toml())


Expand Down
2 changes: 1 addition & 1 deletion modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def print(self) -> None:
if self.code == IssueCode.UNSAFE_OPERATOR:
issue_description = "Unsafe operator"
else:
logger.error(f"No issue description for issue code ${self.code}")
logger.error("No issue description for issue code %s", self.code)

print(f"\n{issue_description} found:")
print(f" - Severity: {self.severity.name}")
Expand Down
23 changes: 14 additions & 9 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _load_scanners(self) -> None:
self._scanners_to_run.append(scanner_class)

except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
logger.error("Error importing scanner %s", scanner_path)
self._init_errors.append(
ModelScanError(
scanner_path,
Expand All @@ -81,7 +81,7 @@ def _load_scanners(self) -> None:

def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
if not model_path.exists():
logger.error(f"Path {model_path} does not exist")
logger.error("Path %s does not exist", model_path)
self._errors.append(
ModelScanError(
"ModelScan",
Expand All @@ -93,7 +93,7 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:

files = [model_path]
if model_path.is_dir():
logger.debug(f"Path {str(model_path)} is a directory")
logger.debug("Path %s is a directory", str(model_path))
files = [f for f in model_path.rglob("*") if Path.is_file(f)]

for file in files:
Expand Down Expand Up @@ -127,8 +127,8 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
yield Model(file_name, file_io)
except zipfile.BadZipFile as e:
logger.debug(
f"Skipping zip file {str(model.get_source())}, due to error",
e,
"Skipping zip file %s, due to error",
str(model.get_source()),
exc_info=True,
)
self._skipped.append(
Expand Down Expand Up @@ -189,7 +189,10 @@ def _scan_source(
scan_results = scanner.scan(model)
except Exception as e:
logger.error(
f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}"
"Error encountered from scanner %s with path %s: %s",
scanner.full_name(),
str(model.get_source()),
e,
)
self._errors.append(
ModelScanError(
Expand All @@ -204,7 +207,9 @@ def _scan_source(
if scan_results is not None:
scanned = True
logger.info(
f"Scanning {model.get_source()} using {scanner.full_name()} model scan"
"Scanning %s using %s model scan",
model.get_source(),
scanner.full_name(),
)
if scan_results.errors:
self._errors.extend(scan_results.errors)
Expand All @@ -224,7 +229,7 @@ def _scan_source(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
"Model Scan did not scan file",
str(model.get_source()),
)
)
Expand Down Expand Up @@ -343,7 +348,7 @@ def generate_report(self) -> Optional[str]:
scan_report = report_class.generate(scan=self, settings=report_settings)

except Exception as e:
logger.error(f"Error generating report using {reporting_module}: {e}")
logger.error("Error generating report using %s: %s", reporting_module, e)
self._errors.append(
ModelScanError(
"ModelScan",
Expand Down
9 changes: 4 additions & 5 deletions modelscan/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from rich import print

from modelscan.modelscan import ModelScan
from modelscan.error import Error
from modelscan.issues import Issues, IssueSeverity
from modelscan.issues import IssueSeverity

logger = logging.getLogger("modelscan")

Expand All @@ -33,7 +32,7 @@ def generate(

:param errors: Any errors that occurred during the scan.
"""
raise NotImplemented
raise NotImplementedError


class ConsoleReport(Report):
Expand All @@ -47,7 +46,7 @@ def generate(
total_issue_count = len(scan.issues.all_issues)
if total_issue_count > 0:
print(f"\nTotal Issues: {total_issue_count}")
print(f"\nTotal Issues By Severity:\n")
print("\nTotal Issues By Severity:\n")
for severity in IssueSeverity:
if severity.name in issues_by_severity:
print(
Expand Down Expand Up @@ -76,7 +75,7 @@ def generate(
f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list."
)
if settings["show_skipped"]:
print(f"\nSkipped files list:\n")
print("\nSkipped files list:\n")
for file_name in scan.skipped:
print(str(file_name))

Expand Down
6 changes: 3 additions & 3 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
ModelScanError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
"Not a valid JSON data",
str(model.get_source()),
)
],
Expand All @@ -83,7 +83,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
ModelScanSkipped(
self.name(),
SkipCategories.MODEL_CONFIG,
f"Model Config not found",
"Model Config not found",
str(model.get_source()),
)
],
Expand All @@ -102,7 +102,7 @@ def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:

with h5py.File(model.get_stream()) as model_hdf5:
try:
if not "model_config" in model_hdf5.attrs.keys():
if "model_config" not in model_hdf5.attrs.keys():
return None

model_config = json.loads(model_hdf5.attrs.get("model_config", {}))
Expand Down
7 changes: 3 additions & 4 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import json
import zipfile
import logging
from pathlib import Path
from typing import IO, List, Union, Optional
from typing import List, Optional


from modelscan.error import ModelScanError, ErrorCategories
Expand Down Expand Up @@ -67,7 +66,7 @@ def scan(self, model: Model) -> Optional[ScanResults]:
ModelScanError(
self.name(),
ErrorCategories.MODEL_SCAN, # Giving a generic error category as this return is added to pass mypy
f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
str(model.get_source()),
)
],
Expand All @@ -92,7 +91,7 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:
ModelScanError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
"Not a valid JSON data",
str(model.get_source()),
)
],
Expand Down
5 changes: 2 additions & 3 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import json
import logging
from pathlib import Path

from typing import IO, List, Set, Union, Optional, Dict, Any
from typing import List, Set, Optional, Dict, Any

try:
import tensorflow
Expand Down Expand Up @@ -121,7 +120,7 @@ def _scan(self, model: Model) -> Optional[ScanResults]:
ModelScanError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
"Not a valid JSON data",
str(model.get_source()),
)
],
Expand Down
4 changes: 0 additions & 4 deletions modelscan/skip.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import abc
import logging
from enum import Enum
from pathlib import Path
from typing import Any, List, Union, Dict, Optional

from collections import defaultdict

logger = logging.getLogger("modelscan")

Expand Down
2 changes: 1 addition & 1 deletion modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def scan_pytorch(model: Model, settings: Dict[str, Any]) -> ScanResults:
ModelScanSkipped(
scan_name,
SkipCategories.MAGIC_NUMBER,
f"Invalid magic number",
"Invalid magic number",
str(model.get_source()),
)
],
Expand Down
2 changes: 0 additions & 2 deletions modelscan/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import http.client
import io
import urllib.parse
from pathlib import Path
from pickletools import genops # nosec
from typing import IO, Optional, Union
Expand Down
Loading
Loading