From 19d17d0aeaba1e9655f7d874e42940ec995818d4 Mon Sep 17 00:00:00 2001 From: asofter Date: Tue, 26 Mar 2024 14:32:23 +0100 Subject: [PATCH] * refactoring of errors --- modelscan/error.py | 119 +++++++++++++++++++------ modelscan/modelscan.py | 61 +++++-------- modelscan/reports.py | 3 +- modelscan/scanners/h5/scan.py | 16 ++-- modelscan/scanners/keras/scan.py | 19 ++-- modelscan/scanners/saved_model/scan.py | 16 ++-- modelscan/scanners/scan.py | 6 +- modelscan/tools/picklescanner.py | 7 +- 8 files changed, 147 insertions(+), 100 deletions(-) diff --git a/modelscan/error.py b/modelscan/error.py index 800a547..7e227c4 100644 --- a/modelscan/error.py +++ b/modelscan/error.py @@ -1,44 +1,111 @@ -from typing import Optional from enum import Enum +from modelscan.model import Model +import abc +from pathlib import Path +from typing import Dict -class ErrorCategories(Enum): - MODEL_SCAN = 1 - DEPENDENCY = 2 - PATH = 3 - NESTED_ZIP = 4 - PICKLE_GENOPS = 5 - JSON_DECODE = 6 - - -class Error: - scan_name: str - category: ErrorCategories +class ErrorBase(metaclass=abc.ABCMeta): message: str - source: Optional[str] - def __init__(self) -> None: - pass + def __init__(self, message: str) -> None: + self.message = message + @abc.abstractmethod def __str__(self) -> str: raise NotImplementedError() + @staticmethod + @abc.abstractmethod + def name() -> str: + raise NotImplementedError + + def to_dict(self) -> Dict[str, str]: + return { + "category": self.name(), + "description": self.message, + } + + +class ModelScanError(ErrorBase): + def __str__(self) -> str: + return f"The following error was raised: \n{self.message}" + + @staticmethod + def name() -> str: + return "MODEL_SCAN" + + +class ModelScanScannerError(ModelScanError): + scan_name: str + model: Model -class ModelScanError(Error): def __init__( self, scan_name: str, - category: ErrorCategories, message: str, - source: Optional[str] = None, + model: Model, ) -> None: + super().__init__(message) self.scan_name = scan_name - self.category = category - self.message = message - self.source = source + self.model = model def __str__(self) -> str: - if self.source: - return f"The following error was raised during a {self.scan_name} scan of file {self.source}: \n{self.message}" - else: - return f"The following error was raised during a {self.scan_name} scan: \n{self.message}" + return f"The following error was raised during a {self.scan_name} scan: \n{self.message}" + + def to_dict(self) -> Dict[str, str]: + return { + "category": self.name(), + "description": self.message, + "source": str(self.model.get_source()), + } + + +class DependencyError(ModelScanScannerError): + @staticmethod + def name() -> str: + return "DEPENDENCY" + + +class PathError(ErrorBase): + path: Path + + def __init__( + self, + message: str, + path: Path, + ) -> None: + super().__init__(message) + self.path = path + + def __str__(self) -> str: + return f"The following error was raised during scan of file {str(self.path)}: \n{self.message}" + + @staticmethod + def name() -> str: + return "PATH" + + def to_dict(self) -> Dict[str, str]: + return { + "category": self.name(), + "description": self.message, + "source": str(self.path), + } + + +class NestedZipError(PathError): + @staticmethod + def name() -> str: + return "NESTED_ZIP" + + +class PickleGenopsError(ModelScanScannerError): + @staticmethod + def name() -> str: + return "PICKLE_GENOPS" + + +class JsonDecodeError(ModelScanScannerError): + @staticmethod + def name() -> str: + return "JSON_DECODE" diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index a8faa86..f222069 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -8,7 +8,13 @@ from datetime import datetime import zipfile -from modelscan.error import ModelScanError, ErrorCategories +from modelscan.error import ( + ModelScanError, + PathError, + ErrorBase, + ModelScanScannerError, + NestedZipError, +) from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.issues import Issues, IssueSeverity from modelscan.scanners.scan import ScanBase @@ -27,7 +33,7 @@ def __init__( ) -> None: # Output self._issues = Issues() - self._errors: List[ModelScanError] = [] + self._errors: List[ErrorBase] = [] self._init_errors: List[ModelScanError] = [] self._skipped: List[ModelScanSkipped] = [] self._scanned: List[str] = [] @@ -46,13 +52,7 @@ def _load_middlewares(self) -> None: ) except MiddlewareImportError as e: logger.exception(e) - self._init_errors.append( - ModelScanError( - "MiddlewarePipeline", - ErrorCategories.MODEL_SCAN, - f"Error loading middlewares: {e}", - ) - ) + self._init_errors.append(ModelScanError(f"Error loading middlewares: {e}")) def _load_scanners(self) -> None: for scanner_path, scanner_settings in self._settings["scanners"].items(): @@ -73,23 +73,14 @@ def _load_scanners(self) -> None: logger.error(f"Error importing scanner {scanner_path}") self._init_errors.append( ModelScanError( - scanner_path, - ErrorCategories.MODEL_SCAN, - f"Error importing scanner: {e}", + f"Error importing scanner {scanner_path}: {e}", ) ) def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]: if not model_path.exists(): logger.error(f"Path {model_path} does not exist") - self._errors.append( - ModelScanError( - "ModelScan", - ErrorCategories.PATH, - "Path is not valid", - str(model_path), - ) - ) + self._errors.append(PathError("Path is not valid", model_path)) files = [model_path] if model_path.is_dir(): @@ -115,11 +106,9 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]: file_name = f"{model.get_source()}:{file_name}" if _is_zipfile(file_name, data=file_io): self._errors.append( - ModelScanError( - "ModelScan", - ErrorCategories.NESTED_ZIP, + NestedZipError( "ModelScan does not support nested zip files.", - file_name, + Path(file_name), ) ) continue @@ -192,11 +181,10 @@ def _scan_source( f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}" ) self._errors.append( - ModelScanError( + ModelScanScannerError( scanner.full_name(), - ErrorCategories.MODEL_SCAN, - f"Error encountered from scanner {scanner.full_name()}: {e}", - str(model.get_source()), + str(e), + model, ) ) continue @@ -282,12 +270,9 @@ def _generate_results(self) -> Dict[str, Any]: all_errors = [] if self._errors: for error in self._errors: - error_information = {} - error_information["category"] = str(error.category.name) - if error.message: - error_information["description"] = error.message - if error.source is not None: - resolved_file = Path(error.source).resolve() + error_information = error.to_dict() + if "source" in error_information: + resolved_file = Path(error_information["source"]).resolve() error_information["source"] = str( resolved_file.relative_to(Path(absolute_path)) ) @@ -345,11 +330,7 @@ def generate_report(self) -> Optional[str]: except Exception as e: logger.error(f"Error generating report using {reporting_module}: {e}") self._errors.append( - ModelScanError( - "ModelScan", - ErrorCategories.MODEL_SCAN, - f"Error generating report using {reporting_module}: {e}", - ) + ModelScanError(f"Error generating report using {reporting_module}: {e}") ) return scan_report @@ -359,7 +340,7 @@ def issues(self) -> Issues: return self._issues @property - def errors(self) -> List[ModelScanError]: + def errors(self) -> List[ErrorBase]: return self._errors @property diff --git a/modelscan/reports.py b/modelscan/reports.py index 1fbf246..9dacb0e 100644 --- a/modelscan/reports.py +++ b/modelscan/reports.py @@ -6,8 +6,7 @@ from rich import print from modelscan.modelscan import ModelScan -from modelscan.error import Error -from modelscan.issues import Issues, IssueSeverity +from modelscan.issues import IssueSeverity logger = logging.getLogger("modelscan") diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index 3e81cb0..f95ff8c 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -10,7 +10,10 @@ except ImportError: h5py_installed = False -from modelscan.error import ModelScanError, ErrorCategories +from modelscan.error import ( + DependencyError, + JsonDecodeError, +) from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan @@ -32,10 +35,10 @@ def scan( return ScanResults( [], [ - ModelScanError( + DependencyError( self.name(), - ErrorCategories.DEPENDENCY, f"To use {self.full_name()}, please install modelscan with h5py extras. `pip install 'modelscan[ h5py ]'` if you are using pip.", + model, ) ], [], @@ -58,11 +61,10 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]: return ScanResults( [], [ - ModelScanError( + JsonDecodeError( self.name(), - ErrorCategories.JSON_DECODE, f"Not a valid JSON data", - str(model.get_source()), + model, ) ], [], @@ -128,7 +130,7 @@ def handle_binary_dependencies( self, settings: Optional[Dict[str, Any]] = None ) -> Optional[str]: if not h5py_installed: - return ErrorCategories.DEPENDENCY.name + return DependencyError.name() return None @staticmethod diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 8e6eb12..50acd1a 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -1,11 +1,10 @@ import json import zipfile import logging -from pathlib import Path -from typing import IO, List, Union, Optional +from typing import List, Optional -from modelscan.error import ModelScanError, ErrorCategories +from modelscan.error import DependencyError, ModelScanScannerError, JsonDecodeError from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan @@ -25,10 +24,10 @@ def scan(self, model: Model) -> Optional[ScanResults]: return ScanResults( [], [ - ModelScanError( + DependencyError( self.name(), - ErrorCategories.DEPENDENCY, f"To use {self.full_name()}, please install modelscan with tensorflow extras. `pip install 'modelscan[ tensorflow ]'` if you are using pip.", + model, ) ], [], @@ -64,11 +63,10 @@ def scan(self, model: Model) -> Optional[ScanResults]: return ScanResults( [], [ - ModelScanError( + ModelScanScannerError( self.name(), - ErrorCategories.MODEL_SCAN, # Giving a generic error category as this return is added to pass mypy f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError - str(model.get_source()), + model, ) ], [], @@ -89,11 +87,10 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults: return ScanResults( [], [ - ModelScanError( + JsonDecodeError( self.name(), - ErrorCategories.JSON_DECODE, f"Not a valid JSON data", - str(model.get_source()), + model, ) ], [], diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index c78bb73..58d1414 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -16,7 +16,10 @@ tensorflow_installed = False -from modelscan.error import ModelScanError, ErrorCategories +from modelscan.error import ( + DependencyError, + JsonDecodeError, +) from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanBase, ScanResults from modelscan.model import Model @@ -37,10 +40,10 @@ def scan( return ScanResults( [], [ - ModelScanError( + DependencyError( self.name(), - ErrorCategories.DEPENDENCY, f"To use {self.full_name()}, please install modelscan with tensorflow extras. `pip install 'modelscan[ tensorflow ]'` if you are using pip.", + model, ) ], [], @@ -93,7 +96,7 @@ def handle_binary_dependencies( self, settings: Optional[Dict[str, Any]] = None ) -> Optional[str]: if not tensorflow_installed: - return ErrorCategories.DEPENDENCY.name + return DependencyError.name() return None @staticmethod @@ -118,11 +121,10 @@ def _scan(self, model: Model) -> Optional[ScanResults]: return ScanResults( [], [ - ModelScanError( + JsonDecodeError( self.name(), - ErrorCategories.JSON_DECODE, f"Not a valid JSON data", - str(model.get_source()), + model, ) ], [], diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py index 89a5981..a49811a 100644 --- a/modelscan/scanners/scan.py +++ b/modelscan/scanners/scan.py @@ -1,7 +1,7 @@ import abc from typing import List, Optional, Any, Dict -from modelscan.error import ModelScanError +from modelscan.error import ErrorBase from modelscan.skip import ModelScanSkipped from modelscan.issues import Issue from modelscan.model import Model @@ -9,13 +9,13 @@ class ScanResults: issues: List[Issue] - errors: List[ModelScanError] + errors: List[ErrorBase] skipped: List[ModelScanSkipped] def __init__( self, issues: List[Issue], - errors: List[ModelScanError], + errors: List[ErrorBase], skipped: List[ModelScanSkipped], ) -> None: self.issues = issues diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index b6e6078..4f1ffe5 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -5,7 +5,7 @@ import numpy as np -from modelscan.error import ModelScanError, ErrorCategories +from modelscan.error import ModelScanError, PickleGenopsError from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanResults @@ -129,11 +129,10 @@ def scan_pickle_bytes( return ScanResults( issues, [ - ModelScanError( + PickleGenopsError( scan_name, - ErrorCategories.PICKLE_GENOPS, f"Parsing error: {e}", - str(model.get_source()), + model, ) ], [],