Skip to content

Commit

Permalink
Merge branch 'main' into iamfaisalkhan-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
iamfaisalkhan authored Apr 8, 2024
2 parents 509001a + d5f289c commit 37cfaff
Show file tree
Hide file tree
Showing 9 changed files with 151 additions and 104 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,24 @@ default_language_version:
python: python3.11
repos:
- repo: https://github.com/psf/black
rev: 22.8.0
rev: 24.3.0
hooks:
- id: black
- repo: https://github.com/python-poetry/poetry
rev: "1.4.0"
rev: "1.7.1"
hooks:
- id: poetry-check # Makes sure poetry config is valid
- id: poetry-lock # Makes sure lock file is up to date
args: ["--check"]
- repo: https://github.com/PyCQA/bandit
rev: "1.7.5"
rev: "1.7.8"
hooks:
- id: bandit
args: ["-c", "pyproject.toml"]
additional_dependencies: ["bandit[toml]"]
exclude: notebooks
- repo: https://github.com/pre-commit/mirrors-mypy
rev: "v1.4.1"
rev: "v1.8.0"
hooks:
- id: mypy
args: ["--ignore-missing-imports", "--strict", "--check-untyped-defs"]
Expand Down
119 changes: 93 additions & 26 deletions modelscan/error.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,111 @@
from typing import Optional
from enum import Enum
from modelscan.model import Model
import abc
from pathlib import Path
from typing import Dict


class ErrorCategories(Enum):
MODEL_SCAN = 1
DEPENDENCY = 2
PATH = 3
NESTED_ZIP = 4
PICKLE_GENOPS = 5
JSON_DECODE = 6


class Error:
scan_name: str
category: ErrorCategories
class ErrorBase(metaclass=abc.ABCMeta):
message: str
source: Optional[str]

def __init__(self) -> None:
pass
def __init__(self, message: str) -> None:
self.message = message

@abc.abstractmethod
def __str__(self) -> str:
raise NotImplementedError()

@staticmethod
@abc.abstractmethod
def name() -> str:
raise NotImplementedError

def to_dict(self) -> Dict[str, str]:
return {
"category": self.name(),
"description": self.message,
}


class ModelScanError(ErrorBase):
def __str__(self) -> str:
return f"The following error was raised: \n{self.message}"

@staticmethod
def name() -> str:
return "MODEL_SCAN"


class ModelScanScannerError(ModelScanError):
scan_name: str
model: Model

class ModelScanError(Error):
def __init__(
self,
scan_name: str,
category: ErrorCategories,
message: str,
source: Optional[str] = None,
model: Model,
) -> None:
super().__init__(message)
self.scan_name = scan_name
self.category = category
self.message = message
self.source = source
self.model = model

def __str__(self) -> str:
if self.source:
return f"The following error was raised during a {self.scan_name} scan of file {self.source}: \n{self.message}"
else:
return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"
return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"

def to_dict(self) -> Dict[str, str]:
return {
"category": self.name(),
"description": self.message,
"source": str(self.model.get_source()),
}


class DependencyError(ModelScanScannerError):
@staticmethod
def name() -> str:
return "DEPENDENCY"


class PathError(ErrorBase):
path: Path

def __init__(
self,
message: str,
path: Path,
) -> None:
super().__init__(message)
self.path = path

def __str__(self) -> str:
return f"The following error was raised during scan of file {str(self.path)}: \n{self.message}"

@staticmethod
def name() -> str:
return "PATH"

def to_dict(self) -> Dict[str, str]:
return {
"category": self.name(),
"description": self.message,
"source": str(self.path),
}


class NestedZipError(PathError):
@staticmethod
def name() -> str:
return "NESTED_ZIP"


class PickleGenopsError(ModelScanScannerError):
@staticmethod
def name() -> str:
return "PICKLE_GENOPS"


class JsonDecodeError(ModelScanScannerError):
@staticmethod
def name() -> str:
return "JSON_DECODE"
61 changes: 21 additions & 40 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from datetime import datetime
import zipfile

from modelscan.error import ModelScanError, ErrorCategories
from modelscan.error import (
ModelScanError,
PathError,
ErrorBase,
ModelScanScannerError,
NestedZipError,
)
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.issues import Issues, IssueSeverity
from modelscan.scanners.scan import ScanBase
Expand All @@ -27,7 +33,7 @@ def __init__(
) -> None:
# Output
self._issues = Issues()
self._errors: List[ModelScanError] = []
self._errors: List[ErrorBase] = []
self._init_errors: List[ModelScanError] = []
self._skipped: List[ModelScanSkipped] = []
self._scanned: List[str] = []
Expand All @@ -46,13 +52,7 @@ def _load_middlewares(self) -> None:
)
except MiddlewareImportError as e:
logger.exception(e)
self._init_errors.append(
ModelScanError(
"MiddlewarePipeline",
ErrorCategories.MODEL_SCAN,
f"Error loading middlewares: {e}",
)
)
self._init_errors.append(ModelScanError(f"Error loading middlewares: {e}"))

def _load_scanners(self) -> None:
for scanner_path, scanner_settings in self._settings["scanners"].items():
Expand All @@ -73,23 +73,14 @@ def _load_scanners(self) -> None:
logger.error(f"Error importing scanner {scanner_path}")
self._init_errors.append(
ModelScanError(
scanner_path,
ErrorCategories.MODEL_SCAN,
f"Error importing scanner: {e}",
f"Error importing scanner {scanner_path}: {e}",
)
)

def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
if not model_path.exists():
logger.error(f"Path {model_path} does not exist")
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.PATH,
"Path is not valid",
str(model_path),
)
)
self._errors.append(PathError("Path is not valid", model_path))

files = [model_path]
if model_path.is_dir():
Expand All @@ -115,11 +106,9 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
file_name = f"{model.get_source()}:{file_name}"
if _is_zipfile(file_name, data=file_io):
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.NESTED_ZIP,
NestedZipError(
"ModelScan does not support nested zip files.",
file_name,
Path(file_name),
)
)
continue
Expand Down Expand Up @@ -192,11 +181,10 @@ def _scan_source(
f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}"
)
self._errors.append(
ModelScanError(
ModelScanScannerError(
scanner.full_name(),
ErrorCategories.MODEL_SCAN,
f"Error encountered from scanner {scanner.full_name()}: {e}",
str(model.get_source()),
str(e),
model,
)
)
continue
Expand Down Expand Up @@ -282,12 +270,9 @@ def _generate_results(self) -> Dict[str, Any]:
all_errors = []
if self._errors:
for error in self._errors:
error_information = {}
error_information["category"] = str(error.category.name)
if error.message:
error_information["description"] = error.message
if error.source is not None:
resolved_file = Path(error.source).resolve()
error_information = error.to_dict()
if "source" in error_information:
resolved_file = Path(error_information["source"]).resolve()
error_information["source"] = str(
resolved_file.relative_to(Path(absolute_path))
)
Expand Down Expand Up @@ -345,11 +330,7 @@ def generate_report(self) -> Optional[str]:
except Exception as e:
logger.error(f"Error generating report using {reporting_module}: {e}")
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.MODEL_SCAN,
f"Error generating report using {reporting_module}: {e}",
)
ModelScanError(f"Error generating report using {reporting_module}: {e}")
)

return scan_report
Expand All @@ -359,7 +340,7 @@ def issues(self) -> Issues:
return self._issues

@property
def errors(self) -> List[ModelScanError]:
def errors(self) -> List[ErrorBase]:
return self._errors

@property
Expand Down
3 changes: 1 addition & 2 deletions modelscan/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from rich import print

from modelscan.modelscan import ModelScan
from modelscan.error import Error
from modelscan.issues import Issues, IssueSeverity
from modelscan.issues import IssueSeverity

logger = logging.getLogger("modelscan")

Expand Down
16 changes: 9 additions & 7 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
except ImportError:
h5py_installed = False

from modelscan.error import ModelScanError, ErrorCategories
from modelscan.error import (
DependencyError,
JsonDecodeError,
)
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
Expand All @@ -32,10 +35,10 @@ def scan(
return ScanResults(
[],
[
ModelScanError(
DependencyError(
self.name(),
ErrorCategories.DEPENDENCY,
f"To use {self.full_name()}, please install modelscan with h5py extras. `pip install 'modelscan[ h5py ]'` if you are using pip.",
model,
)
],
[],
Expand All @@ -58,11 +61,10 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
return ScanResults(
[],
[
ModelScanError(
JsonDecodeError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
str(model.get_source()),
model,
)
],
[],
Expand Down Expand Up @@ -128,7 +130,7 @@ def handle_binary_dependencies(
self, settings: Optional[Dict[str, Any]] = None
) -> Optional[str]:
if not h5py_installed:
return ErrorCategories.DEPENDENCY.name
return DependencyError.name()
return None

@staticmethod
Expand Down
Loading

0 comments on commit 37cfaff

Please sign in to comment.