Skip to content

Commit

Permalink
Proposal to improve errors (#122)
Browse files Browse the repository at this point in the history
Co-authored-by: asofter <[email protected]>
Co-authored-by: Faisal Khan <[email protected]>
  • Loading branch information
3 people authored Apr 8, 2024
1 parent 9b92478 commit d5f289c
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 100 deletions.
119 changes: 93 additions & 26 deletions modelscan/error.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,111 @@
from typing import Optional
from enum import Enum
from modelscan.model import Model
import abc
from pathlib import Path
from typing import Dict


class ErrorCategories(Enum):
MODEL_SCAN = 1
DEPENDENCY = 2
PATH = 3
NESTED_ZIP = 4
PICKLE_GENOPS = 5
JSON_DECODE = 6


class Error:
scan_name: str
category: ErrorCategories
class ErrorBase(metaclass=abc.ABCMeta):
message: str
source: Optional[str]

def __init__(self) -> None:
pass
def __init__(self, message: str) -> None:
self.message = message

@abc.abstractmethod
def __str__(self) -> str:
raise NotImplementedError()

@staticmethod
@abc.abstractmethod
def name() -> str:
raise NotImplementedError

def to_dict(self) -> Dict[str, str]:
return {
"category": self.name(),
"description": self.message,
}


class ModelScanError(ErrorBase):
def __str__(self) -> str:
return f"The following error was raised: \n{self.message}"

@staticmethod
def name() -> str:
return "MODEL_SCAN"


class ModelScanScannerError(ModelScanError):
scan_name: str
model: Model

class ModelScanError(Error):
def __init__(
self,
scan_name: str,
category: ErrorCategories,
message: str,
source: Optional[str] = None,
model: Model,
) -> None:
super().__init__(message)
self.scan_name = scan_name
self.category = category
self.message = message
self.source = source
self.model = model

def __str__(self) -> str:
if self.source:
return f"The following error was raised during a {self.scan_name} scan of file {self.source}: \n{self.message}"
else:
return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"
return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"

def to_dict(self) -> Dict[str, str]:
return {
"category": self.name(),
"description": self.message,
"source": str(self.model.get_source()),
}


class DependencyError(ModelScanScannerError):
@staticmethod
def name() -> str:
return "DEPENDENCY"


class PathError(ErrorBase):
path: Path

def __init__(
self,
message: str,
path: Path,
) -> None:
super().__init__(message)
self.path = path

def __str__(self) -> str:
return f"The following error was raised during scan of file {str(self.path)}: \n{self.message}"

@staticmethod
def name() -> str:
return "PATH"

def to_dict(self) -> Dict[str, str]:
return {
"category": self.name(),
"description": self.message,
"source": str(self.path),
}


class NestedZipError(PathError):
@staticmethod
def name() -> str:
return "NESTED_ZIP"


class PickleGenopsError(ModelScanScannerError):
@staticmethod
def name() -> str:
return "PICKLE_GENOPS"


class JsonDecodeError(ModelScanScannerError):
@staticmethod
def name() -> str:
return "JSON_DECODE"
61 changes: 21 additions & 40 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from datetime import datetime
import zipfile

from modelscan.error import ModelScanError, ErrorCategories
from modelscan.error import (
ModelScanError,
PathError,
ErrorBase,
ModelScanScannerError,
NestedZipError,
)
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.issues import Issues, IssueSeverity
from modelscan.scanners.scan import ScanBase
Expand All @@ -27,7 +33,7 @@ def __init__(
) -> None:
# Output
self._issues = Issues()
self._errors: List[ModelScanError] = []
self._errors: List[ErrorBase] = []
self._init_errors: List[ModelScanError] = []
self._skipped: List[ModelScanSkipped] = []
self._scanned: List[str] = []
Expand All @@ -46,13 +52,7 @@ def _load_middlewares(self) -> None:
)
except MiddlewareImportError as e:
logger.exception(e)
self._init_errors.append(
ModelScanError(
"MiddlewarePipeline",
ErrorCategories.MODEL_SCAN,
f"Error loading middlewares: {e}",
)
)
self._init_errors.append(ModelScanError(f"Error loading middlewares: {e}"))

def _load_scanners(self) -> None:
for scanner_path, scanner_settings in self._settings["scanners"].items():
Expand All @@ -73,23 +73,14 @@ def _load_scanners(self) -> None:
logger.error(f"Error importing scanner {scanner_path}")
self._init_errors.append(
ModelScanError(
scanner_path,
ErrorCategories.MODEL_SCAN,
f"Error importing scanner: {e}",
f"Error importing scanner {scanner_path}: {e}",
)
)

def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
if not model_path.exists():
logger.error(f"Path {model_path} does not exist")
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.PATH,
"Path is not valid",
str(model_path),
)
)
self._errors.append(PathError("Path is not valid", model_path))

files = [model_path]
if model_path.is_dir():
Expand All @@ -115,11 +106,9 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
file_name = f"{model.get_source()}:{file_name}"
if _is_zipfile(file_name, data=file_io):
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.NESTED_ZIP,
NestedZipError(
"ModelScan does not support nested zip files.",
file_name,
Path(file_name),
)
)
continue
Expand Down Expand Up @@ -192,11 +181,10 @@ def _scan_source(
f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}"
)
self._errors.append(
ModelScanError(
ModelScanScannerError(
scanner.full_name(),
ErrorCategories.MODEL_SCAN,
f"Error encountered from scanner {scanner.full_name()}: {e}",
str(model.get_source()),
str(e),
model,
)
)
continue
Expand Down Expand Up @@ -282,12 +270,9 @@ def _generate_results(self) -> Dict[str, Any]:
all_errors = []
if self._errors:
for error in self._errors:
error_information = {}
error_information["category"] = str(error.category.name)
if error.message:
error_information["description"] = error.message
if error.source is not None:
resolved_file = Path(error.source).resolve()
error_information = error.to_dict()
if "source" in error_information:
resolved_file = Path(error_information["source"]).resolve()
error_information["source"] = str(
resolved_file.relative_to(Path(absolute_path))
)
Expand Down Expand Up @@ -345,11 +330,7 @@ def generate_report(self) -> Optional[str]:
except Exception as e:
logger.error(f"Error generating report using {reporting_module}: {e}")
self._errors.append(
ModelScanError(
"ModelScan",
ErrorCategories.MODEL_SCAN,
f"Error generating report using {reporting_module}: {e}",
)
ModelScanError(f"Error generating report using {reporting_module}: {e}")
)

return scan_report
Expand All @@ -359,7 +340,7 @@ def issues(self) -> Issues:
return self._issues

@property
def errors(self) -> List[ModelScanError]:
def errors(self) -> List[ErrorBase]:
return self._errors

@property
Expand Down
3 changes: 1 addition & 2 deletions modelscan/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from rich import print

from modelscan.modelscan import ModelScan
from modelscan.error import Error
from modelscan.issues import Issues, IssueSeverity
from modelscan.issues import IssueSeverity

logger = logging.getLogger("modelscan")

Expand Down
16 changes: 9 additions & 7 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
except ImportError:
h5py_installed = False

from modelscan.error import ModelScanError, ErrorCategories
from modelscan.error import (
DependencyError,
JsonDecodeError,
)
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
Expand All @@ -32,10 +35,10 @@ def scan(
return ScanResults(
[],
[
ModelScanError(
DependencyError(
self.name(),
ErrorCategories.DEPENDENCY,
f"To use {self.full_name()}, please install modelscan with h5py extras. `pip install 'modelscan[ h5py ]'` if you are using pip.",
model,
)
],
[],
Expand All @@ -58,11 +61,10 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
return ScanResults(
[],
[
ModelScanError(
JsonDecodeError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
str(model.get_source()),
model,
)
],
[],
Expand Down Expand Up @@ -128,7 +130,7 @@ def handle_binary_dependencies(
self, settings: Optional[Dict[str, Any]] = None
) -> Optional[str]:
if not h5py_installed:
return ErrorCategories.DEPENDENCY.name
return DependencyError.name()
return None

@staticmethod
Expand Down
19 changes: 8 additions & 11 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
import zipfile
import logging
from pathlib import Path
from typing import IO, List, Union, Optional
from typing import List, Optional


from modelscan.error import ModelScanError, ErrorCategories
from modelscan.error import DependencyError, ModelScanScannerError, JsonDecodeError
from modelscan.skip import ModelScanSkipped, SkipCategories
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
Expand All @@ -25,10 +24,10 @@ def scan(self, model: Model) -> Optional[ScanResults]:
return ScanResults(
[],
[
ModelScanError(
DependencyError(
self.name(),
ErrorCategories.DEPENDENCY,
f"To use {self.full_name()}, please install modelscan with tensorflow extras. `pip install 'modelscan[ tensorflow ]'` if you are using pip.",
model,
)
],
[],
Expand Down Expand Up @@ -64,11 +63,10 @@ def scan(self, model: Model) -> Optional[ScanResults]:
return ScanResults(
[],
[
ModelScanError(
ModelScanScannerError(
self.name(),
ErrorCategories.MODEL_SCAN, # Giving a generic error category as this return is added to pass mypy
f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError
str(model.get_source()),
model,
)
],
[],
Expand All @@ -89,11 +87,10 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:
return ScanResults(
[],
[
ModelScanError(
JsonDecodeError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
str(model.get_source()),
model,
)
],
[],
Expand Down
Loading

0 comments on commit d5f289c

Please sign in to comment.