Skip to content

Commit

Permalink
refactor: address linting suggestions (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
CandiedCode authored Apr 11, 2024
1 parent 69f09f5 commit 8b10baa
Show file tree
Hide file tree
Showing 17 changed files with 241 additions and 201 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,8 @@ cython_debug/
.DS_Store

.vscode/
.idea/
.idea/

# Notebook Model Downloads
notebooks/PyTorchModels/
pytorch-model-scan-results.json
1 change: 1 addition & 0 deletions modelscan/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""CLI for scanning models"""

import logging

from modelscan._version import __version__
Expand Down
17 changes: 9 additions & 8 deletions modelscan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import os
from pathlib import Path
from typing import Optional, Dict, Any
from typing import Optional
from tomlkit import parse

import click
Expand Down Expand Up @@ -87,7 +87,7 @@ def cli() -> None:
help="Optional file name for output report",
)
@cli.command(
help="[Default] Scan a model file or diretory for ability to execute suspicious actions. "
help="[Default] Scan a model file or directory for ability to execute suspicious actions. "
) # type: ignore
@click.pass_context
def scan(
Expand All @@ -112,7 +112,7 @@ def scan(
settings = DEFAULT_SETTINGS

if settings_file_path and settings_file_path.is_file():
with open(settings_file_path) as sf:
with open(settings_file_path, encoding="utf-8") as sf:
settings = parse(sf.read()).unwrap()
click.echo(f"Detected settings file. Using {settings_file_path}. \n")
else:
Expand All @@ -132,7 +132,7 @@ def scan(
raise click.UsageError("Command line must include a path")

# Report scan results
if reporting_format is not "custom":
if reporting_format != "custom":
modelscan._settings["reporting"]["module"] = DEFAULT_REPORTING_MODULES[
reporting_format
]
Expand Down Expand Up @@ -174,16 +174,17 @@ def create_settings(force: bool, location: Optional[str]) -> None:
settings_path = location

try:
open(settings_path)
open(settings_path, encoding="utf-8")
if force:
with open(settings_path, "w") as settings_file:
with open(settings_path, mode="w", encoding="utf-8") as settings_file:
settings_file.write(SettingsUtils.get_default_settings_as_toml())
else:
logger.warning(
f"{settings_path} file already exists. Please use `--force` flag if you intend to overwrite it."
"%s file already exists. Please use `--force` flag if you intend to overwrite it.",
settings_path,
)
except FileNotFoundError:
with open(settings_path, "w") as settings_file:
with open(settings_path, mode="w", encoding="utf-8") as settings_file:
settings_file.write(SettingsUtils.get_default_settings_as_toml())


Expand Down
2 changes: 1 addition & 1 deletion modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def print(self) -> None:
if self.code == IssueCode.UNSAFE_OPERATOR:
issue_description = "Unsafe operator"
else:
logger.error(f"No issue description for issue code ${self.code}")
logger.error("No issue description for issue code %s", self.code)

print(f"\n{issue_description} found:")
print(f" - Severity: {self.severity.name}")
Expand Down
23 changes: 14 additions & 9 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _load_scanners(self) -> None:
self._scanners_to_run.append(scanner_class)

except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
logger.error("Error importing scanner %s", scanner_path)
self._init_errors.append(
ModelScanError(
f"Error importing scanner {scanner_path}: {e}",
Expand All @@ -79,12 +79,12 @@ def _load_scanners(self) -> None:

def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
if not model_path.exists():
logger.error(f"Path {model_path} does not exist")
logger.error("Path %s does not exist", model_path)
self._errors.append(PathError("Path is not valid", model_path))

files = [model_path]
if model_path.is_dir():
logger.debug(f"Path {str(model_path)} is a directory")
logger.debug("Path %s is a directory", str(model_path))
files = [f for f in model_path.rglob("*") if Path.is_file(f)]

for file in files:
Expand Down Expand Up @@ -116,8 +116,8 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
yield Model(file_name, file_io)
except zipfile.BadZipFile as e:
logger.debug(
f"Skipping zip file {str(model.get_source())}, due to error",
e,
"Skipping zip file %s, due to error",
str(model.get_source()),
exc_info=True,
)
self._skipped.append(
Expand Down Expand Up @@ -178,7 +178,10 @@ def _scan_source(
scan_results = scanner.scan(model)
except Exception as e:
logger.error(
f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}"
"Error encountered from scanner %s with path %s: %s",
scanner.full_name(),
str(model.get_source()),
e,
)
self._errors.append(
ModelScanScannerError(
Expand All @@ -192,7 +195,9 @@ def _scan_source(
if scan_results is not None:
scanned = True
logger.info(
f"Scanning {model.get_source()} using {scanner.full_name()} model scan"
"Scanning %s using %s model scan",
model.get_source(),
scanner.full_name(),
)
if scan_results.errors:
self._errors.extend(scan_results.errors)
Expand All @@ -212,7 +217,7 @@ def _scan_source(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
"Model Scan did not scan file",
str(model.get_source()),
)
)
Expand Down Expand Up @@ -328,7 +333,7 @@ def generate_report(self) -> Optional[str]:
scan_report = report_class.generate(scan=self, settings=report_settings)

except Exception as e:
logger.error(f"Error generating report using {reporting_module}: {e}")
logger.error("Error generating report using %s: %s", reporting_module, e)
self._errors.append(
ModelScanError(f"Error generating report using {reporting_module}: {e}")
)
Expand Down
6 changes: 3 additions & 3 deletions modelscan/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def generate(
:param errors: Any errors that occurred during the scan.
"""
raise NotImplemented
raise NotImplementedError


class ConsoleReport(Report):
Expand All @@ -46,7 +46,7 @@ def generate(
total_issue_count = len(scan.issues.all_issues)
if total_issue_count > 0:
print(f"\nTotal Issues: {total_issue_count}")
print(f"\nTotal Issues By Severity:\n")
print("\nTotal Issues By Severity:\n")
for severity in IssueSeverity:
if severity.name in issues_by_severity:
print(
Expand Down Expand Up @@ -75,7 +75,7 @@ def generate(
f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list."
)
if settings["show_skipped"]:
print(f"\nSkipped files list:\n")
print("\nSkipped files list:\n")
for file_name in scan.skipped:
print(str(file_name))

Expand Down
6 changes: 3 additions & 3 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
[
JsonDecodeError(
self.name(),
f"Not a valid JSON data",
"Not a valid JSON data",
model,
)
],
Expand All @@ -85,7 +85,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
ModelScanSkipped(
self.name(),
SkipCategories.MODEL_CONFIG,
f"Model Config not found",
"Model Config not found",
str(model.get_source()),
)
],
Expand All @@ -104,7 +104,7 @@ def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:

with h5py.File(model.get_stream()) as model_hdf5:
try:
if not "model_config" in model_hdf5.attrs.keys():
if "model_config" not in model_hdf5.attrs.keys():
return None

model_config = json.loads(model_hdf5.attrs.get("model_config", {}))
Expand Down
4 changes: 2 additions & 2 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def scan(self, model: Model) -> Optional[ScanResults]:
[
ModelScanScannerError(
self.name(),
f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
model,
)
],
Expand All @@ -89,7 +89,7 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:
[
JsonDecodeError(
self.name(),
f"Not a valid JSON data",
"Not a valid JSON data",
model,
)
],
Expand Down
5 changes: 2 additions & 3 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

import json
import logging
from pathlib import Path

from typing import IO, List, Set, Union, Optional, Dict, Any
from typing import List, Set, Optional, Dict, Any

try:
import tensorflow
Expand Down Expand Up @@ -123,7 +122,7 @@ def _scan(self, model: Model) -> Optional[ScanResults]:
[
JsonDecodeError(
self.name(),
f"Not a valid JSON data",
"Not a valid JSON data",
model,
)
],
Expand Down
4 changes: 0 additions & 4 deletions modelscan/skip.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import abc
import logging
from enum import Enum
from pathlib import Path
from typing import Any, List, Union, Dict, Optional

from collections import defaultdict

logger = logging.getLogger("modelscan")

Expand Down
2 changes: 1 addition & 1 deletion modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def scan_pytorch(model: Model, settings: Dict[str, Any]) -> ScanResults:
ModelScanSkipped(
scan_name,
SkipCategories.MAGIC_NUMBER,
f"Invalid magic number",
"Invalid magic number",
str(model.get_source()),
)
],
Expand Down
2 changes: 0 additions & 2 deletions modelscan/tools/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import http.client
import io
import urllib.parse
from pathlib import Path
from pickletools import genops # nosec
from typing import IO, Optional, Union
Expand Down
Loading

0 comments on commit 8b10baa

Please sign in to comment.