From 667816b68a76b4d8a9c207c1fd836d910f93f50a Mon Sep 17 00:00:00 2001
From: Mehrin Kiani
Date: Thu, 22 Feb 2024 12:13:51 -0500
Subject: [PATCH] Added description for skipped files

---
 modelscan/error.py                     |  21 ++--
 modelscan/modelscan.py                 | 121 ++++++++++++++++-------
 modelscan/scanners/h5/scan.py          |  80 +++++++++------
 modelscan/scanners/keras/scan.py       |  61 +++++++----
 modelscan/scanners/saved_model/scan.py |  52 ++++++----
 modelscan/scanners/scan.py             |  16 ++-
 modelscan/skip.py                      |  47 ++++++++
 modelscan/tools/picklescanner.py       |  49 ++++++--
 tests/test_modelscan.py                |  97 +++++++++--------
 9 files changed, 385 insertions(+), 159 deletions(-)
 create mode 100644 modelscan/skip.py

diff --git a/modelscan/error.py b/modelscan/error.py
index 677d81b..fdff016 100644
--- a/modelscan/error.py
+++ b/modelscan/error.py
@@ -3,12 +3,21 @@
 
 
 class ErrorCategories(Enum):
-    MODEL_FILE = 1
-    JSON_DATA = 2
-    DEPENDENCY = 3
+    MODEL_SCAN = 1
+    DEPENDENCY = 2
+    PATH = 3
+    NESTED_ZIP = 4
+    PICKLE_GENOPS = 5
+    MAGIC_NUMBER = 6
+    JSON_DECODE = 7
 
 
 class Error:
+    scan_name: str
+    category: ErrorCategories
+    message: Optional[str]
+    source: Optional[str]
+
     def __init__(self) -> None:
         pass
 
@@ -17,17 +26,15 @@ def __str__(self) -> str:
 
 
 class ModelScanError(Error):
-    scan_name: str
-    message: Optional[str]
-    source: Optional[str]
-
     def __init__(
         self,
         scan_name: str,
+        category: ErrorCategories,
         message: Optional[str] = None,
        source: Optional[str] = None,
    ) -> None:
        self.scan_name = scan_name
+        self.category = category
        self.message = message or "None"
        self.source = str(source)
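Every ModelScanError now carries an ErrorCategories member alongside its message, which is what lets the report group failures by kind instead of by string matching. A minimal sketch of the reworked constructor; the scan name and source path below are illustrative values only:

    from modelscan.error import ErrorCategories, ModelScanError

    # Hypothetical values; any scanner name and source path work here.
    err = ModelScanError(
        scan_name="ModelScan",
        category=ErrorCategories.PATH,
        message="Path is not valid",
        source="models/missing.pt",
    )

    # _generate_results() serializes the enum via its .name attribute.
    assert err.category.name == "PATH"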
"Path is not valid", str(Path) + ) ) - self._skipped.append(str(path)) def _scan_directory(self, directory_path: Path) -> None: for path in directory_path.rglob("*"): @@ -111,12 +126,18 @@ def _scan_source( source=source, data=data, ) + if scan_results is not None: logger.info(f"Scanning {source} using {scanner.full_name()} model scan") - self._scanned.append(str(source)) - self._issues.add_issues(scan_results.issues) - self._errors.extend(scan_results.errors) - scanned = True + if scan_results.errors: + self._errors.extend(scan_results.errors) + elif scan_results.skipped: + self._skipped.extend(scan_results.skipped) + else: + self._scanned.append(str(source)) + self._issues.add_issues(scan_results.issues) + scanned = True + return scanned def _scan_zip( @@ -131,18 +152,42 @@ def _scan_zip( source=f"{source}:{file_name}", data=file_io, ) + if not scanned: if _is_zipfile(file_name, data=file_io): self._errors.append( ModelScanError( "ModelScan", - f"{source}:{file_name} is a zip file. ModelScan does not support nested zip files.", + ErrorCategories.NESTED_ZIP, + "ModelScan does not support nested zip files.", + f"{source}:{file_name}", ) ) - self._skipped.append(f"{source}:{file_name}") + + # check if added to skipped already + all_skipped_files = [ + skipped.source for skipped in self._skipped + ] + if f"{source}:{file_name}" not in all_skipped_files: + self._skipped.append( + ModelScanSkipped( + "ModelScan", + SkipCategories.SCAN_NOT_SUPPORTED, + f"Model Scan did not scan file", + f"{source}:{file_name}", + ) + ) + except zipfile.BadZipFile as e: logger.debug(f"Skipping zip file {source}, due to error", e, exc_info=True) - self._skipped.append(str(source)) + self._skipped.append( + ModelScanSkipped( + "ModelScan", + SkipCategories.BAD_ZIP, + f"Skipping zip file due to error: {e}", + f"{source}:{file_name}", + ) + ) def _generate_results(self) -> Dict[str, Any]: report: Dict[str, Any] = {} @@ -168,11 +213,7 @@ def _generate_results(self) -> Dict[str, Any]: report["summary"]["absolute_path"] = str(absolute_path) report["summary"]["modelscan_version"] = __version__ report["summary"]["timestamp"] = datetime.now().isoformat() - report["summary"]["skipped"] = {"total_skipped": len(self._skipped)} - report["summary"]["skipped"]["skipped_files"] = [ - str(Path(file_name).relative_to(Path(absolute_path))) - for file_name in self._skipped - ] + report["summary"]["scanned"] = {"total_scanned": len(self._scanned)} report["summary"]["scanned"]["scanned_files"] = [ str(Path(file_name).relative_to(Path(absolute_path))) @@ -190,17 +231,35 @@ def _generate_results(self) -> Dict[str, Any]: all_errors = [] - for err in self._errors: - error = {} - if err.message is not None: - error["description"] = err.message - if hasattr(err, "source"): - error["source"] = str(Path(err.source).relative_to(Path(absolute_path))) - if error: - all_errors.append(error) + for error in self._errors: + error_information = {} + error_information["category"] = str(error.category.name) + if error.message is not None: + error_information["description"] = error.message + if hasattr(error, "source"): + error_information["source"] = str( + Path(str(error.source)).relative_to(Path(absolute_path)) + ) + + all_errors.append(error_information) report["errors"] = all_errors + report["summary"]["skipped"] = {"total_skipped": len(self._skipped)} + + all_skipped_files = [] + + for skipped_file in self._skipped: + skipped_file_information = {} + skipped_file_information["category"] = str(skipped_file.category.name) + 
skipped_file_information["description"] = str(skipped_file.message) + skipped_file_information["source"] = str( + Path(skipped_file.source).relative_to(Path(absolute_path)) + ) + all_skipped_files.append(skipped_file_information) + + report["summary"]["skipped"]["skipped_files"] = all_skipped_files + return report def is_compatible(self, path: str) -> bool: @@ -222,7 +281,7 @@ def issues(self) -> Issues: return self._issues @property - def errors(self) -> List[Error]: + def errors(self) -> List[ModelScanError]: return self._errors @property @@ -230,5 +289,5 @@ def scanned(self) -> List[str]: return self._scanned @property - def skipped(self) -> List[str]: + def skipped(self) -> List[ModelScanSkipped]: return self._skipped diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index a919f4a..2f3f193 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -12,6 +12,7 @@ h5py_installed = False from modelscan.error import ModelScanError, ErrorCategories +from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan @@ -31,7 +32,6 @@ def scan( ] ): return None - dep_error = self.handle_binary_dependencies() if dep_error: return ScanResults( @@ -39,9 +39,11 @@ def scan( [ ModelScanError( self.name(), + ErrorCategories.DEPENDENCY, f"To use {self.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.", ) ], + [], ) if data: @@ -49,12 +51,14 @@ def scan( f"{self.full_name()} got data bytes. It only support direct file scanning." ) return ScanResults( + [], [], [ - ModelScanError( + ModelScanSkipped( self.name(), + SkipCategories.H5_DATA, f"{self.full_name()} got data bytes. 
diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py
index 8741bd0..d975ebf 100644
--- a/modelscan/scanners/keras/scan.py
+++ b/modelscan/scanners/keras/scan.py
@@ -6,6 +6,7 @@
 
 
 from modelscan.error import ModelScanError, ErrorCategories
+from modelscan.skip import ModelScanSkipped, SkipCategories
 from modelscan.scanners.scan import ScanResults
 from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
 
@@ -34,9 +35,11 @@ def scan(
                 [
                     ModelScanError(
                         self.name(),
+                        ErrorCategories.DEPENDENCY,
                         f"To use {self.full_name()}, please install modelscan with dependencies.",
                     )
                 ],
+                [],
             )
 
         try:
@@ -53,10 +56,14 @@ def scan(
             )
         except zipfile.BadZipFile as e:
             return ScanResults(
+                [],
                 [],
                 [
-                    ModelScanError(
-                        self.name(), f"Skipping zip file due to error: {e}", source
+                    ModelScanSkipped(
+                        self.name(),
+                        SkipCategories.BAD_ZIP,
+                        f"Skipping zip file due to error: {e}",
+                        str(source),
                     )
                 ],
             )
@@ -67,49 +74,69 @@ def scan(
             [
                 ModelScanError(
                     self.name(),
+                    ErrorCategories.MODEL_SCAN,  # generic category; this return exists to satisfy mypy
                     f"Unable to scan .keras file",
-                    source,
+                    str(source),
                 )
             ],
+            [],
         )
 
     def _scan_keras_config_file(
         self, source: Union[str, Path], config_file: IO[bytes]
     ) -> ScanResults:
         machine_learning_library_name = "Keras"
+
         operators_in_model = self._get_keras_operator_names(source, config_file)
-        if ErrorCategories.JSON_DATA.name in operators_in_model:
+        if operators_in_model:
+            if "JSONDecodeError" in operators_in_model:
+                return ScanResults(
+                    [],
+                    [
+                        ModelScanError(
+                            self.name(),
+                            ErrorCategories.JSON_DECODE,
+                            "Not valid JSON data",
+                            str(source),
+                        )
+                    ],
+                    [],
+                )
+
+            return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator(
+                module_name=machine_learning_library_name,
+                raw_operator=operators_in_model,
+                source=source,
+                unsafe_operators=self._settings["scanners"][
+                    SavedModelLambdaDetectScan.full_name()
+                ]["unsafe_keras_operators"],
+            )
+
+        else:
             return ScanResults(
                 [],
-                [ModelScanError(self.name(), f"Not a valid JSON data", source)],
+                [],
+                [],
             )
 
-        return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator(
-            module_name=machine_learning_library_name,
-            raw_operator=operators_in_model,
-            source=source,
-            unsafe_operators=self._settings["scanners"][
-                SavedModelLambdaDetectScan.full_name()
-            ]["unsafe_keras_operators"],
-        )
-
     def _get_keras_operator_names(
         self, source: Union[str, Path], data: IO[bytes]
-    ) -> List[Any]:
+    ) -> List[str]:
         try:
             model_config_data = json.load(data)
+
             lambda_layers = [
                 layer.get("config", {}).get("function", {})
                 for layer in model_config_data.get("config", {}).get("layers", {})
                 if layer.get("class_name", {}) == "Lambda"
             ]
             if lambda_layers:
                 return ["Lambda"] * len(lambda_layers)
 
         except json.JSONDecodeError as e:
             logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
-
-            return [ErrorCategories.JSON_DATA.name]
+            return ["JSONDecodeError"]
 
         return []
diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py
index 01b8736..0e4cf3a 100644
--- a/modelscan/scanners/saved_model/scan.py
+++ b/modelscan/scanners/saved_model/scan.py
@@ -42,9 +42,11 @@ def scan(
                 [
                     ModelScanError(
                         self.name(),
+                        ErrorCategories.DEPENDENCY,
                         f"To use {self.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.",
                     )
                 ],
+                [],
             )
 
         if data:
@@ -96,13 +98,13 @@ def _check_for_unsafe_tf_keras_operator(
                     ),
                 )
             )
-        return ScanResults(issues, [])
+        return ScanResults(issues, [], [])
 
     def handle_binary_dependencies(
         self, settings: Optional[Dict[str, Any]] = None
-    ) -> Optional[ErrorCategories]:
+    ) -> Optional[str]:
         if not tensorflow_installed:
-            return [ErrorCategories.DEPENDENCY.name]
+            return ErrorCategories.DEPENDENCY.name
         return None
 
     @staticmethod
@@ -122,29 +124,35 @@ def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResul
         operators_in_model = self._get_keras_pb_operator_names(
             data=data, source=source
         )
-        if ErrorCategories.JSON_DATA.name in operators_in_model:
-            return ScanResults(
-                [],
-                [ModelScanError(self.name(), f"Not a valid JSON data", source)],
-            )
+        if operators_in_model:
+            if "JSONDecodeError" in operators_in_model:
+                return ScanResults(
+                    [],
+                    [
+                        ModelScanError(
+                            self.name(),
+                            ErrorCategories.JSON_DECODE,
+                            "Not valid JSON data",
+                            str(source),
+                        )
+                    ],
+                    [],
+                )
 
-        if operators_in_model is None:
-            return None
+            return SavedModelScan._check_for_unsafe_tf_keras_operator(
+                machine_learning_library_name,
+                operators_in_model,
+                source,
+                self._settings["scanners"][self.full_name()]["unsafe_keras_operators"],
+            )
         else:
             return None
 
-        return SavedModelScan._check_for_unsafe_tf_keras_operator(
-            machine_learning_library_name,
-            operators_in_model,
-            source,
-            self._settings["scanners"][self.full_name()]["unsafe_keras_operators"],
-        )
-
     @staticmethod
     def _get_keras_pb_operator_names(
         data: IO[bytes], source: Union[str, Path]
-    ) -> List[Any]:
+    ) -> List[str]:
         saved_metadata = SavedMetadata()
         saved_metadata.ParseFromString(data.read())
 
@@ -158,12 +166,12 @@ def _get_keras_pb_operator_names(
                 ]
                 if layer.get("class_name", {}) == "Lambda"
             ]
+            if lambda_layers:
+                return ["Lambda"] * len(lambda_layers)
+
         except json.JSONDecodeError as e:
             logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
-            return [ErrorCategories.JSON_DATA.name]
-
-        if lambda_layers:
-            return ["Lambda"] * len(lambda_layers)
+            return ["JSONDecodeError"]
 
         return []
'pip install \"modelscan\[tensorflow]\"' if you are using pip.", ) ], + [], ) if data: @@ -96,13 +98,13 @@ def _check_for_unsafe_tf_keras_operator( ), ) ) - return ScanResults(issues, []) + return ScanResults(issues, [], []) def handle_binary_dependencies( self, settings: Optional[Dict[str, Any]] = None - ) -> Optional[ErrorCategories]: + ) -> Optional[str]: if not tensorflow_installed: - return [ErrorCategories.DEPENDENCY.name] + return ErrorCategories.DEPENDENCY.name return None @staticmethod @@ -122,29 +124,35 @@ def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResul operators_in_model = self._get_keras_pb_operator_names( data=data, source=source ) - if ErrorCategories.JSON_DATA.name in operators_in_model: - return ScanResults( - [], - [ModelScanError(self.name(), f"Not a valid JSON data", source)], - ) + if operators_in_model: + if "JSONDecodeError" in operators_in_model: + return ScanResults( + [], + [ + ModelScanError( + self.name(), + ErrorCategories.JSON_DECODE, + f"Not a valid JSON data", + str(source), + ) + ], + [], + ) - if operators_in_model is None: - return None + return SavedModelScan._check_for_unsafe_tf_keras_operator( + machine_learning_library_name, + operators_in_model, + source, + self._settings["scanners"][self.full_name()]["unsafe_keras_operators"], + ) else: return None - return SavedModelScan._check_for_unsafe_tf_keras_operator( - machine_learning_library_name, - operators_in_model, - source, - self._settings["scanners"][self.full_name()]["unsafe_keras_operators"], - ) - @staticmethod def _get_keras_pb_operator_names( data: IO[bytes], source: Union[str, Path] - ) -> List[Any]: + ) -> List[str]: saved_metadata = SavedMetadata() saved_metadata.ParseFromString(data.read()) @@ -158,12 +166,12 @@ def _get_keras_pb_operator_names( ] if layer.get("class_name", {}) == "Lambda" ] + if lambda_layers: + return ["Lambda"] * len(lambda_layers) + except json.JSONDecodeError as e: logger.error(f"Not a valid JSON data from source: {source}, error: {e}") - return [ErrorCategories.JSON_DATA.name] - - if lambda_layers: - return ["Lambda"] * len(lambda_layers) + return ["JSONDecodeError"] return [] diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py index 681c87c..681ef70 100644 --- a/modelscan/scanners/scan.py +++ b/modelscan/scanners/scan.py @@ -2,17 +2,25 @@ from pathlib import Path from typing import List, Union, Optional, IO, Any, Dict -from modelscan.error import Error, ModelScanError +from modelscan.error import ModelScanError +from modelscan.skip import ModelScanSkipped from modelscan.issues import Issue class ScanResults: issues: List[Issue] - errors: List[Error] + errors: List[ModelScanError] + skipped: List[ModelScanSkipped] - def __init__(self, issues: List[Issue], errors: List[Error]) -> None: + def __init__( + self, + issues: List[Issue], + errors: List[ModelScanError], + skipped: List[ModelScanSkipped], + ) -> None: self.issues = issues self.errors = errors + self.skipped = skipped class ScanBase(metaclass=abc.ABCMeta): @@ -42,7 +50,7 @@ def scan( def handle_binary_dependencies( self, settings: Optional[Dict[str, Any]] = None - ) -> Optional[ModelScanError]: + ) -> Optional[str]: """ Implement this method if the plugin requires a binary dependency. 
diff --git a/modelscan/skip.py b/modelscan/skip.py
new file mode 100644
index 0000000..2dd556e
--- /dev/null
+++ b/modelscan/skip.py
@@ -0,0 +1,47 @@
+import abc
+import logging
+from enum import Enum
+from pathlib import Path
+from typing import Any, List, Union, Dict, Optional
+
+from collections import defaultdict
+
+logger = logging.getLogger("modelscan")
+
+
+class SkipCategories(Enum):
+    SCAN_NOT_SUPPORTED = 1
+    BAD_ZIP = 2
+    MODEL_CONFIG = 3
+    H5_DATA = 4
+    NOT_IMPLEMENTED = 5
+
+
+class Skip:
+    scan_name: str
+    category: SkipCategories
+    message: str
+    source: str
+
+    def __init__(self) -> None:
+        pass
+
+    def __str__(self) -> str:
+        raise NotImplementedError()
+
+
+class ModelScanSkipped(Skip):
+    def __init__(
+        self,
+        scan_name: str,
+        category: SkipCategories,
+        message: str,
+        source: str,
+    ) -> None:
+        self.scan_name = scan_name
+        self.category = category
+        self.message = message or "None"
+        self.source = str(source)
+
+    def __str__(self) -> str:
+        return f"The following file {self.source} was skipped during a {self.scan_name} scan: \n{self.message}"
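Because skips are typed objects rather than bare path strings, callers can aggregate them without string parsing. A sketch that buckets skipped files by category, assuming an already-populated list (the entries below are hypothetical):

    from collections import defaultdict
    from typing import Dict, List

    from modelscan.skip import ModelScanSkipped, SkipCategories

    def group_by_category(
        skipped: List[ModelScanSkipped],
    ) -> Dict[SkipCategories, List[str]]:
        grouped: Dict[SkipCategories, List[str]] = defaultdict(list)
        for entry in skipped:
            grouped[entry.category].append(entry.source)
        return grouped

    skips = [
        ModelScanSkipped("ModelScan", SkipCategories.SCAN_NOT_SUPPORTED,
                         "ModelScan did not scan file", "model/weights.bin"),
        ModelScanSkipped("ModelScan", SkipCategories.BAD_ZIP,
                         "Skipping zip file due to error: bad magic", "model/corrupt.zip"),
    ]
    for category, sources in group_by_category(skips).items():
        print(category.name, sources)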
_ZIP_PREFIX = b"PK\x03\x04" _ZIP_SUFFIX = b"PK\x05\x06" # empty zip files start with this @@ -188,7 +198,19 @@ def scan_numpy( data.seek(-min(N, len(magic)), 1) # back-up if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX): # .npz file - raise NotImplementedError("Scanning of .npz files is not implemented yet") + return ScanResults( + [], + [], + [ + ModelScanSkipped( + scan_name, + SkipCategories.NOT_IMPLEMENTED, + "Scanning of .npz files is not implemented yet", + str(source), + ) + ], + ) + # raise NotImplementedError("Scanning of .npz files is not implemented yet") elif magic == np.lib.format.MAGIC_PREFIX: # .npy file version = np.lib.format.read_magic(data) # type: ignore[no-untyped-call] @@ -196,16 +218,17 @@ def scan_numpy( _, _, dtype = np.lib.format._read_array_header(data, version) # type: ignore[attr-defined] if dtype.hasobject: - return scan_pickle_bytes(data, source, settings, "numpy") + return scan_pickle_bytes(data, source, settings, scan_name) else: - return ScanResults([], []) + return ScanResults([], [], []) else: - return scan_pickle_bytes(data, source, settings, "numpy") + return scan_pickle_bytes(data, source, settings, scan_name) def scan_pytorch( data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any] ) -> ScanResults: + scan_name = "pytorch" should_read_directly = _should_read_directly(data) if should_read_directly and data.tell() == 0: # try loading from tar @@ -219,6 +242,15 @@ def scan_pytorch( magic = get_magic_number(data) if magic != MAGIC_NUMBER: return ScanResults( - [], [ModelScanError("pytorch", f"Invalid magic number", source)] + [], + [ + ModelScanError( + scan_name, + ErrorCategories.MAGIC_NUMBER, + f"Invalid magic number", + str(source), + ), + ], + [], ) - return scan_pickle_bytes(data, source, settings, "pytorch", multiple_pickles=False) + return scan_pickle_bytes(data, source, settings, scan_name, multiple_pickles=False) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 7c8831d..2352db0 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -36,12 +36,12 @@ scan_pickle_bytes, scan_numpy, ) + +from modelscan.error import ErrorCategories from modelscan.settings import DEFAULT_SETTINGS settings: Dict[str, Any] = DEFAULT_SETTINGS -from modelscan.scanners.saved_model.scan import SavedModelScan - class Malicious1: def __reduce__(self) -> Any: @@ -463,8 +463,14 @@ def test_scan_zip(zip_file_path: Any) -> None: def test_scan_pytorch(pytorch_file_path: Any) -> None: ms = ModelScan() results = ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt")) - assert results["summary"]["scanned"]["scanned_files"] == [f"bad_pytorch.pt"] - assert results["summary"]["skipped"]["skipped_files"] == [] + + assert results["errors"] == [ + { + "category": ErrorCategories.MAGIC_NUMBER.name, + "description": f"Invalid magic number", + "source": f"bad_pytorch.pt", + } + ] assert ms.issues.all_issues == [] assert [error.scan_name for error in ms.errors] == ["pytorch"] # type: ignore[attr-defined] @@ -472,7 +478,11 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None: assert results["summary"]["scanned"]["scanned_files"] == [ f"safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl" ] - assert results["summary"]["skipped"]["skipped_files"] == [ + + assert [ + skipped_file["source"] + for skipped_file in results["summary"]["skipped"]["skipped_files"] + ] == [ "safe_zip_pytorch.pt:safe_zip_pytorch/byteorder", "safe_zip_pytorch.pt:safe_zip_pytorch/version", "safe_zip_pytorch.pt:safe_zip_pytorch/.data/serialization_id", @@ -505,7 
diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py
index 7c8831d..2352db0 100644
--- a/tests/test_modelscan.py
+++ b/tests/test_modelscan.py
@@ -36,12 +36,12 @@
     scan_pickle_bytes,
     scan_numpy,
 )
+
+from modelscan.error import ErrorCategories
 from modelscan.settings import DEFAULT_SETTINGS
 
 settings: Dict[str, Any] = DEFAULT_SETTINGS
 
-from modelscan.scanners.saved_model.scan import SavedModelScan
-
 
 class Malicious1:
     def __reduce__(self) -> Any:
@@ -463,8 +463,14 @@ def test_scan_zip(zip_file_path: Any) -> None:
 def test_scan_pytorch(pytorch_file_path: Any) -> None:
     ms = ModelScan()
     results = ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt"))
-    assert results["summary"]["scanned"]["scanned_files"] == [f"bad_pytorch.pt"]
-    assert results["summary"]["skipped"]["skipped_files"] == []
+
+    assert results["errors"] == [
+        {
+            "category": ErrorCategories.MAGIC_NUMBER.name,
+            "description": "Invalid magic number",
+            "source": "bad_pytorch.pt",
+        }
+    ]
     assert ms.issues.all_issues == []
     assert [error.scan_name for error in ms.errors] == ["pytorch"]  # type: ignore[attr-defined]
 
@@ -472,7 +478,11 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
     assert results["summary"]["scanned"]["scanned_files"] == [
         f"safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl"
     ]
-    assert results["summary"]["skipped"]["skipped_files"] == [
+
+    assert [
+        skipped_file["source"]
+        for skipped_file in results["summary"]["skipped"]["skipped_files"]
+    ] == [
         "safe_zip_pytorch.pt:safe_zip_pytorch/byteorder",
         "safe_zip_pytorch.pt:safe_zip_pytorch/version",
         "safe_zip_pytorch.pt:safe_zip_pytorch/.data/serialization_id",
@@ -505,7 +515,10 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
     assert results["summary"]["scanned"]["scanned_files"] == [
         f"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl",
     ]
-    assert results["summary"]["skipped"]["skipped_files"] == [
+    assert [
+        skipped_file["source"]
+        for skipped_file in results["summary"]["skipped"]["skipped_files"]
+    ] == [
         "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/byteorder",
         "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/version",
         "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/.data/serialization_id",
@@ -1251,40 +1264,33 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
     results = ms.scan(Path(safe_filename))
     assert ms.issues.all_issues == []
+
     if file_extension == ".pb":
-        assert set(results["summary"]["scanned"]["scanned_files"]) == {
+        assert results["summary"]["scanned"]["scanned_files"] == [
             f"fingerprint.pb",
             f"keras_metadata.pb",
             f"saved_model.pb",
-        }
-        assert set(results["summary"]["skipped"]["skipped_files"]) == {
-            f"variables/variables.data-00000-of-00001",
+        ]
+
+        assert [
+            skipped_file["source"]
+            for skipped_file in results["summary"]["skipped"]["skipped_files"]
+        ] == [
+            "variables/variables.data-00000-of-00001",
             f"variables/variables.index",
-        }
+        ]
         assert results["errors"] == []
-    elif file_extension == ".keras":
-        assert results["summary"]["scanned"]["scanned_files"] == [
-            f"safe{file_extension}",
-            f"safe{file_extension}:model.weights.h5",
-        ]
-        assert results["summary"]["skipped"]["skipped_files"] == [
-            "safe.keras:metadata.json",
-            "safe.keras:config.json",
-        ]
-        assert results["errors"] == [
-            {
-                "description": "modelscan.scanners.H5LambdaDetectScan got data bytes. It only support direct file scanning.",
-                "source": "safe.keras:model.weights.h5",
-            }
-        ]
     else:
         assert results["summary"]["scanned"]["scanned_files"] == [
             f"safe{file_extension}"
         ]
         if file_extension == ".keras":
-            assert results["summary"]["skipped"]["skipped_files"] == [
+            assert [
+                skipped_file["source"]
+                for skipped_file in results["summary"]["skipped"]["skipped_files"]
+            ] == [
                 f"safe{file_extension}:metadata.json",
                 f"safe{file_extension}:config.json",
                 f"safe{file_extension}:model.weights.h5",
@@ -1320,13 +1326,9 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
             ),
         ]
         results = ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
+
         assert ms.issues.all_issues == expected
-        assert results["errors"] == [
-            {
-                "description": "modelscan.scanners.H5LambdaDetectScan got data bytes. It only support direct file scanning.",
-                "source": "unsafe.keras:model.weights.h5",
-            }
-        ]
+
     elif file_extension == ".pb":
         file_name = "keras_metadata.pb"
         unsafe_filename = f"{unsafe_saved_model_dir}"
@@ -1355,15 +1357,18 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
         results = ms.scan(Path(f"{unsafe_saved_model_dir}"))
         assert ms.issues.all_issues == expected
         assert results["errors"] == []
-        assert set(results["summary"]["scanned"]["scanned_files"]) == {
+        assert results["summary"]["scanned"]["scanned_files"] == [
             f"fingerprint.pb",
             f"keras_metadata.pb",
             f"saved_model.pb",
-        }
-        assert set(results["summary"]["skipped"]["skipped_files"]) == {
+        ]
+        assert [
+            skipped_file["source"]
+            for skipped_file in results["summary"]["skipped"]["skipped_files"]
+        ] == [
             f"variables/variables.data-00000-of-00001",
             f"variables/variables.index",
-        }
+        ]
     else:
         unsafe_filename = f"{keras_file_path_parent_dir}/unsafe{file_extension}"
         expected = [
@@ -1414,15 +1419,18 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
     ms = ModelScan()
     results = ms.scan(Path(f"{safe_tensorflow_model_dir}"))
     assert ms.issues.all_issues == []
-    assert set(results["summary"]["scanned"]["scanned_files"]) == {
+    assert results["summary"]["scanned"]["scanned_files"] == [
         f"fingerprint.pb",
         f"keras_metadata.pb",
         f"saved_model.pb",
-    }
-    assert set(results["summary"]["skipped"]["skipped_files"]) == {
+    ]
+    assert [
+        skipped_file["source"]
+        for skipped_file in results["summary"]["skipped"]["skipped_files"]
+    ] == [
         f"variables/variables.data-00000-of-00001",
         f"variables/variables.index",
-    }
+    ]
     assert results["errors"] == []
 
     file_name = "saved_model.pb"
@@ -1456,7 +1464,12 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
         f"keras_metadata.pb",
         f"saved_model.pb",
     }
-    assert set(results["summary"]["skipped"]["skipped_files"]) == {
+    assert set(
+        [
+            skipped_file["source"]
+            for skipped_file in results["summary"]["skipped"]["skipped_files"]
+        ]
+    ) == {
         f"variables/variables.data-00000-of-00001",
         f"variables/variables.index",
     }
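The skipped property on ModelScan now returns ModelScanSkipped objects, so callers that previously treated ms.skipped as a list of path strings must read the .source attribute, exactly as the updated assertions above do. A short usage sketch against a hypothetical archive:

    from pathlib import Path
    from modelscan.modelscan import ModelScan

    ms = ModelScan()
    results = ms.scan(Path("safe_zip_pytorch.pt"))  # hypothetical file, as in the tests

    # Old: sources = ms.skipped
    # New: each skip is a typed record.
    sources = [skipped.source for skipped in ms.skipped]
    categories = {skipped.category.name for skipped in ms.skipped}
    print(sources, categories)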