diff --git a/modelscan/middlewares/__init__.py b/modelscan/middlewares/__init__.py
new file mode 100644
index 0000000..e6dcece
--- /dev/null
+++ b/modelscan/middlewares/__init__.py
@@ -0,0 +1 @@
+from modelscan.middlewares.format_via_extension import FormatViaExtensionMiddleware
diff --git a/modelscan/middlewares/format_via_extension.py b/modelscan/middlewares/format_via_extension.py
new file mode 100644
index 0000000..2150acf
--- /dev/null
+++ b/modelscan/middlewares/format_via_extension.py
@@ -0,0 +1,17 @@
+from .middleware import MiddlewareBase
+from modelscan.model import Model
+from typing import Callable
+
+
+class FormatViaExtensionMiddleware(MiddlewareBase):
+    def __call__(self, model: Model, call_next: Callable[[Model], None]) -> None:
+        extension = model.get_source().suffix
+        formats = [
+            format
+            for format, extensions in self._settings["formats"].items()
+            if extension in extensions
+        ]
+        if len(formats) > 0:
+            model.set_context("formats", (model.get_context("formats") or []) + formats)
+
+        call_next(model)
diff --git a/modelscan/middlewares/middleware.py b/modelscan/middlewares/middleware.py
new file mode 100644
index 0000000..613c895
--- /dev/null
+++ b/modelscan/middlewares/middleware.py
@@ -0,0 +1,59 @@
+import abc
+from modelscan.model import Model
+from typing import Callable, Dict, Any, List
+import importlib
+
+
+class MiddlewareImportError(Exception):
+    pass
+
+
+class MiddlewareBase(metaclass=abc.ABCMeta):
+    _settings: Dict[str, Any]
+
+    def __init__(self, settings: Dict[str, Any]):
+        self._settings = settings
+
+    @abc.abstractmethod
+    def __call__(
+        self,
+        model: Model,
+        call_next: Callable[[Model], None],
+    ) -> None:
+        raise NotImplementedError
+
+
+class MiddlewarePipeline:
+    _middlewares: List[MiddlewareBase]
+
+    def __init__(self) -> None:
+        self._middlewares = []
+
+    @staticmethod
+    def from_settings(middleware_settings: Dict[str, Any]) -> "MiddlewarePipeline":
+        pipeline = MiddlewarePipeline()
+
+        for path, params in middleware_settings.items():
+            try:
+                (modulename, classname) = path.rsplit(".", 1)
+                imported_module = importlib.import_module(
+                    name=modulename, package=classname
+                )
+
+                middleware_class: MiddlewareBase = getattr(imported_module, classname)
+                pipeline.add_middleware(middleware_class(params))  # type: ignore
+            except Exception as e:
+                raise MiddlewareImportError(f"Error importing middleware {path}: {e}")
+
+        return pipeline
+
+    def add_middleware(self, middleware: MiddlewareBase) -> "MiddlewarePipeline":
+        self._middlewares.append(middleware)
+        return self
+
+    def run(self, model: Model) -> None:
+        def runner(model: Model, index: int) -> None:
+            if index < len(self._middlewares):
+                self._middlewares[index](model, lambda model: runner(model, index + 1))
+
+        runner(model, 0)
diff --git a/modelscan/model.py b/modelscan/model.py
new file mode 100644
index 0000000..43dd611
--- /dev/null
+++ b/modelscan/model.py
@@ -0,0 +1,54 @@
+from pathlib import Path
+from typing import Union, Optional, IO, Dict, Any
+
+
+class ModelDataEmpty(ValueError):
+    pass
+
+
+class Model:
+    _source: Path
+    _stream: Optional[IO[bytes]]
+    _source_file_used: bool
+    _context: Dict[str, Any]
+
+    def __init__(self, source: Union[str, Path], stream: Optional[IO[bytes]] = None):
+        self._source = Path(source)
+        self._stream = stream
+        self._source_file_used = False
+        self._context = {"formats": []}
+
+    def set_context(self, key: str, value: Any) -> None:
+        self._context[key] = value
+
+    def get_context(self, key: str) -> Any:
+        return self._context.get(key)
+
+    def open(self) -> "Model":
"Model": + if self._stream: + return self + + self._stream = open(self._source, "rb") + self._source_file_used = True + + return self + + def close(self) -> None: + # Only close the stream if we opened a file (not for IO[bytes] objects passed in) + if self._stream and self._source_file_used: + self._stream.close() + + def __enter__(self) -> "Model": + return self.open() + + def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore + self.close() + + def get_source(self) -> Path: + return self._source + + def get_stream(self) -> IO[bytes]: + if not self._stream: + raise ModelDataEmpty("Model data is empty.") + + return self._stream diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 9db1af5..2745990 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -1,19 +1,21 @@ import logging -import zipfile import importlib from modelscan.settings import DEFAULT_SETTINGS from pathlib import Path -from typing import List, Union, Optional, IO, Dict, Any +from typing import List, Union, Dict, Any, Optional, Generator from datetime import datetime +import zipfile from modelscan.error import ModelScanError, ErrorCategories from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.issues import Issues, IssueSeverity from modelscan.scanners.scan import ScanBase -from modelscan.tools.utils import _is_zipfile from modelscan._version import __version__ +from modelscan.tools.utils import _is_zipfile +from modelscan.model import Model +from modelscan.middlewares.middleware import MiddlewarePipeline, MiddlewareImportError logger = logging.getLogger("modelscan") @@ -35,6 +37,22 @@ def __init__( self._scanners_to_run: List[ScanBase] = [] self._settings: Dict[str, Any] = settings self._load_scanners() + self._load_middlewares() + + def _load_middlewares(self) -> None: + try: + self._middleware_pipeline = MiddlewarePipeline.from_settings( + self._settings["middlewares"] or {} + ) + except MiddlewareImportError as e: + logger.exception(e) + self._init_errors.append( + ModelScanError( + "MiddlewarePipeline", + ErrorCategories.MODEL_SCAN, + f"Error loading middlewares: {e}", + ) + ) def _load_scanners(self) -> None: for scanner_path, scanner_settings in self._settings["scanners"].items(): @@ -61,6 +79,67 @@ def _load_scanners(self) -> None: ) ) + def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]: + if not model_path.exists(): + logger.error(f"Path {model_path} does not exist") + self._errors.append( + ModelScanError( + "ModelScan", + ErrorCategories.PATH, + "Path is not valid", + str(model_path), + ) + ) + + files = [model_path] + if model_path.is_dir(): + logger.debug(f"Path {str(model_path)} is a directory") + files = [f for f in model_path.rglob("*") if Path.is_file(f)] + + for file in files: + with Model(file) as model: + yield model + + if ( + not _is_zipfile(file, model.get_stream()) + and Path(file).suffix + not in self._settings["supported_zip_extensions"] + ): + continue + + try: + with zipfile.ZipFile(model.get_stream(), "r") as zip: + file_names = zip.namelist() + for file_name in file_names: + with zip.open(file_name, "r") as file_io: + file_name = f"{model.get_source()}:{file_name}" + if _is_zipfile(file_name, data=file_io): + self._errors.append( + ModelScanError( + "ModelScan", + ErrorCategories.NESTED_ZIP, + "ModelScan does not support nested zip files.", + file_name, + ) + ) + continue + + yield Model(file_name, file_io) + except zipfile.BadZipFile as e: + logger.debug( + f"Skipping zip file 
+                        e,
+                        exc_info=True,
+                    )
+                    self._skipped.append(
+                        ModelScanSkipped(
+                            "ModelScan",
+                            SkipCategories.BAD_ZIP,
+                            f"Skipping zip file due to error: {e}",
+                            str(model.get_source()),
+                        )
+                    )
+
     def scan(
         self,
         path: Union[str, Path],
@@ -71,141 +151,91 @@ def scan(
         self._skipped = []
         self._scanned = []
         self._input_path = str(path)
-        pathlibPath = Path().cwd() if path == "." else Path(path).absolute()
-        self._scan_path(Path(pathlibPath))
-        return self._generate_results()
+        pathlib_path = Path().cwd() if path == "." else Path(path).absolute()
+        model_path = Path(pathlib_path)
 
-    def _scan_path(
-        self,
-        path: Path,
-    ) -> None:
-        if Path.exists(path):
-            scanned = self._scan_source(path)
-            if not scanned and path.is_dir():
-                self._scan_directory(path)
-            elif (
-                _is_zipfile(path)
-                or path.suffix in self._settings["supported_zip_extensions"]
-            ):
-                self._scan_zip(path)
-            elif not scanned:
-                # check if added to skipped already
-                all_skipped_files = [skipped.source for skipped in self._skipped]
-                if str(path) not in all_skipped_files:
-                    self._skipped.append(
-                        ModelScanSkipped(
-                            "ModelScan",
-                            SkipCategories.SCAN_NOT_SUPPORTED,
-                            f"Model Scan did not scan file",
-                            str(path),
-                        )
-                    )
+        all_paths: List[Path] = []
+        for model in self._iterate_models(model_path):
+            self._middleware_pipeline.run(model)
+            self._scan_source(model)
+            all_paths.append(model.get_source())
 
-        else:
-            logger.error(f"Error: path {path} is not valid")
-            self._errors.append(
-                ModelScanError(
-                    "ModelScan", ErrorCategories.PATH, "Path is not valid", str(path)
-                )
-            )
+        if self._skipped:
+            all_skipped_paths = [skipped.source for skipped in self._skipped]
+            for skipped in self._skipped:
+                main_file_path = skipped.source.split(":")[0]
+                if main_file_path == skipped.source:
+                    continue
 
-    def _scan_directory(self, directory_path: Path) -> None:
-        for path in directory_path.rglob("*"):
-            if not path.is_dir():
-                self._scan_path(path)
+                # If the main container is skipped, we only keep its contents in skipped, not the file itself
+                if main_file_path in all_skipped_paths:
+                    self._skipped = [
+                        item for item in self._skipped if item.source != main_file_path
+                    ]
+
+                    continue
+
+                # If the main container is scanned, we consider all files in it to be scanned
+                self._skipped = [
+                    item for item in self._skipped if item.source != skipped.source
+                ]
+
+        return self._generate_results()
 
     def _scan_source(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> bool:
         scanned = False
         for scan_class in self._scanners_to_run:
             scanner = scan_class(self._settings)  # type: ignore[operator]
             try:
-                scan_results = scanner.scan(
-                    source=source,
-                    data=data,
-                )
+                scan_results = scanner.scan(model)
             except Exception as e:
                 logger.error(
-                    f"Error encountered from scanner {scanner.full_name()} with path {source}: {e}"
+                    f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}"
                 )
                 self._errors.append(
                     ModelScanError(
                         scanner.full_name(),
                         ErrorCategories.MODEL_SCAN,
                         f"Error encountered from scanner {scanner.full_name()}: {e}",
-                        f"{source}",
+                        str(model.get_source()),
                    )
                 )
+                continue
 
             if scan_results is not None:
                 scanned = True
-                logger.info(f"Scanning {source} using {scanner.full_name()} model scan")
+                logger.info(
+                    f"Scanning {model.get_source()} using {scanner.full_name()} model scan"
+                )
                 if scan_results.errors:
                     self._errors.extend(scan_results.errors)
                 elif scan_results.issues:
-                    self._scanned.append(str(source))
+                    self._scanned.append(str(model.get_source()))
                     self._issues.add_issues(scan_results.issues)
                 elif scan_results.skipped:
                     self._skipped.extend(scan_results.skipped)
                 else:
-                    self._scanned.append(str(source))
+                    self._scanned.append(str(model.get_source()))
+
+        if not scanned:
+            all_skipped_files = [skipped.source for skipped in self._skipped]
+            if str(model.get_source()) not in all_skipped_files:
+                self._skipped.append(
+                    ModelScanSkipped(
+                        "ModelScan",
+                        SkipCategories.SCAN_NOT_SUPPORTED,
+                        "Model Scan did not scan file",
+                        str(model.get_source()),
+                    )
+                )
 
         return scanned
 
-    def _scan_zip(
-        self, source: Union[str, Path], data: Optional[IO[bytes]] = None
-    ) -> None:
-        try:
-            with zipfile.ZipFile(data or source, "r") as zip:
-                file_names = zip.namelist()
-                for file_name in file_names:
-                    with zip.open(file_name, "r") as file_io:
-                        scanned = self._scan_source(
-                            source=f"{source}:{file_name}",
-                            data=file_io,
-                        )
-
-                        if not scanned:
-                            if _is_zipfile(file_name, data=file_io):
-                                self._errors.append(
-                                    ModelScanError(
-                                        "ModelScan",
-                                        ErrorCategories.NESTED_ZIP,
-                                        "ModelScan does not support nested zip files.",
-                                        f"{source}:{file_name}",
-                                    )
-                                )
-
-                            # check if added to skipped already
-                            all_skipped_files = [
-                                skipped.source for skipped in self._skipped
-                            ]
-                            if f"{source}:{file_name}" not in all_skipped_files:
-                                self._skipped.append(
-                                    ModelScanSkipped(
-                                        "ModelScan",
-                                        SkipCategories.SCAN_NOT_SUPPORTED,
-                                        f"Model Scan did not scan file",
-                                        f"{source}:{file_name}",
-                                    )
-                                )
-
-        except zipfile.BadZipFile as e:
-            logger.debug(f"Skipping zip file {source}, due to error", e, exc_info=True)
-            self._skipped.append(
-                ModelScanSkipped(
-                    "ModelScan",
-                    SkipCategories.BAD_ZIP,
-                    f"Skipping zip file due to error: {e}",
-                    f"{source}:{file_name}",
-                )
-            )
-
     def _generate_results(self) -> Dict[str, Any]:
         report: Dict[str, Any] = {}
 
diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py
index ee350fe..d8452e6 100644
--- a/modelscan/scanners/h5/scan.py
+++ b/modelscan/scanners/h5/scan.py
@@ -1,7 +1,6 @@
 import json
 import logging
-from pathlib import Path
-from typing import IO, List, Union, Optional, Dict, Any
+from typing import List, Optional, Dict, Any
 
 
 try:
@@ -15,6 +14,7 @@
 from modelscan.skip import ModelScanSkipped, SkipCategories
 from modelscan.scanners.scan import ScanResults
 from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
+from modelscan.model import Model
 
 logger = logging.getLogger("modelscan")
 
@@ -22,16 +22,11 @@
 class H5LambdaDetectScan(SavedModelLambdaDetectScan):
     def scan(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> Optional[ScanResults]:
-        if (
-            not Path(source).suffix
-            in self._settings["scanners"][H5LambdaDetectScan.full_name()][
-                "supported_extensions"
-            ]
-        ):
+        if "keras_h5" not in model.get_context("formats"):
             return None
+
         dep_error = self.handle_binary_dependencies()
         if dep_error:
             return ScanResults(
@@ -46,33 +41,16 @@ def scan(
             [],
         )
 
-        if data:
-            logger.warning(
-                f"{self.full_name()} got data bytes. It only support direct file scanning."
-            )
-            return ScanResults(
-                [],
-                [],
-                [
-                    ModelScanSkipped(
-                        self.name(),
-                        SkipCategories.H5_DATA,
-                        f"{self.full_name()} got data bytes. It only support direct file scanning.",
-                        str(source),
-                    )
-                ],
-            )
-
-        results = self._scan_keras_h5_file(source)
+        results = self._scan_keras_h5_file(model)
         if results:
             return self.label_results(results)
-        else:
-            return None
 
-    def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]:
+        return None
+
+    def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
         machine_learning_library_name = "Keras"
-        if self._check_model_config(source):
-            operators_in_model = self._get_keras_h5_operator_names(source)
+        if self._check_model_config(model):
+            operators_in_model = self._get_keras_h5_operator_names(model)
             if operators_in_model is None:
                 return None
@@ -84,7 +62,7 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]
                         self.name(),
                         ErrorCategories.JSON_DECODE,
                         f"Not a valid JSON data",
-                        str(source),
+                        str(model.get_source()),
                     )
                 ],
                 [],
@@ -92,7 +70,7 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]
         return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
             module_name=machine_learning_library_name,
             raw_operator=operators_in_model,
-            source=source,
+            model=model,
             unsafe_operators=self._settings["scanners"][
                 SavedModelLambdaDetectScan.full_name()
             ]["unsafe_keras_operators"],
@@ -106,25 +84,23 @@
                     self.name(),
                     SkipCategories.MODEL_CONFIG,
                     f"Model Config not found",
-                    str(source),
+                    str(model.get_source()),
                 )
             ],
         )
 
-    def _check_model_config(self, source: Union[str, Path]) -> bool:
-        with h5py.File(source, "r") as model_hdf5:
+    def _check_model_config(self, model: Model) -> bool:
+        with h5py.File(model.get_stream(), "r") as model_hdf5:
             if "model_config" in model_hdf5.attrs.keys():
                 return True
             else:
-                logger.error(f"Model Config not found in: {source}")
+                logger.error(f"Model Config not found in: {model.get_source()}")
                 return False
 
-    def _get_keras_h5_operator_names(
-        self, source: Union[str, Path]
-    ) -> Optional[List[Any]]:
+    def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:
         # Todo: source isn't guaranteed to be a file
-        with h5py.File(source, "r") as model_hdf5:
+        with h5py.File(model.get_stream(), "r") as model_hdf5:
             try:
                 if not "model_config" in model_hdf5.attrs.keys():
                     return None
@@ -138,7 +114,9 @@ def _get_keras_h5_operator_names(
                         layer.get("config", {}).get("function", {})
                     )
             except json.JSONDecodeError as e:
-                logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
+                logger.error(
+                    f"Not a valid JSON data from source: {model.get_source()}, error: {e}"
+                )
                 return ["JSONDecodeError"]
 
         if lambda_layers:
diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py
index 72180da..08f897e 100644
--- a/modelscan/scanners/keras/scan.py
+++ b/modelscan/scanners/keras/scan.py
@@ -2,30 +2,22 @@
 import zipfile
 import logging
 from pathlib import Path
-from typing import IO, List, Union, Optional, Any
+from typing import IO, List, Union, Optional
 
 from modelscan.error import ModelScanError, ErrorCategories
 from modelscan.skip import ModelScanSkipped, SkipCategories
 from modelscan.scanners.scan import ScanResults
 from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
+from modelscan.model import Model
 
 logger = logging.getLogger("modelscan")
 
 
 class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
-    def scan(
-        self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
-    ) -> Optional[ScanResults]:
-        if (
-            not Path(source).suffix
in self._settings["scanners"][KerasLambdaDetectScan.full_name()][ - "supported_extensions" - ] - ): + def scan(self, model: Model) -> Optional[ScanResults]: + if "keras" not in model.get_context("formats"): return None dep_error = self.handle_binary_dependencies() @@ -43,16 +35,16 @@ def scan( ) try: - with zipfile.ZipFile(data or source, "r") as zip: + with zipfile.ZipFile(model.get_stream(), "r") as zip: file_names = zip.namelist() for file_name in file_names: if file_name == "config.json": with zip.open(file_name, "r") as config_file: + model = Model( + f"{model.get_source()}:{file_name}", config_file + ) return self.label_results( - self._scan_keras_config_file( - source=f"{source}:{file_name}", - config_file=config_file, - ) + self._scan_keras_config_file(model) ) except zipfile.BadZipFile as e: return ScanResults( @@ -63,7 +55,7 @@ def scan( self.name(), SkipCategories.BAD_ZIP, f"Skipping zip file due to error: {e}", - f"{source}:{file_name}", + f"{model.get_source()}:{file_name}", ) ], ) @@ -76,20 +68,18 @@ def scan( self.name(), ErrorCategories.MODEL_SCAN, # Giving a generic error category as this return is added to pass mypy f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError - str(source), + str(model.get_source()), ) ], [], ) - def _scan_keras_config_file( - self, source: Union[str, Path], config_file: IO[bytes] - ) -> ScanResults: + def _scan_keras_config_file(self, model: Model) -> ScanResults: machine_learning_library_name = "Keras" # if self._check_json_data(source, config_file): - operators_in_model = self._get_keras_operator_names(source, config_file) + operators_in_model = self._get_keras_operator_names(model) if operators_in_model: if "JSONDecodeError" in operators_in_model: return ScanResults( @@ -99,7 +89,7 @@ def _scan_keras_config_file( self.name(), ErrorCategories.JSON_DECODE, f"Not a valid JSON data", - str(source), + str(model.get_source()), ) ], [], @@ -108,7 +98,7 @@ def _scan_keras_config_file( return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, - source=source, + model=model, unsafe_operators=self._settings["scanners"][ SavedModelLambdaDetectScan.full_name() ]["unsafe_keras_operators"], @@ -121,11 +111,9 @@ def _scan_keras_config_file( [], ) - def _get_keras_operator_names( - self, source: Union[str, Path], data: IO[bytes] - ) -> List[str]: + def _get_keras_operator_names(self, model: Model) -> List[str]: try: - model_config_data = json.load(data) + model_config_data = json.load(model.get_stream()) lambda_layers = [ layer.get("config", {}).get("function", {}) @@ -136,7 +124,9 @@ def _get_keras_operator_names( return ["Lambda"] * len(lambda_layers) except json.JSONDecodeError as e: - logger.error(f"Not a valid JSON data from source: {source}, error: {e}") + logger.error( + f"Not a valid JSON data from source: {model.get_source()}, error: {e}" + ) return ["JSONDecodeError"] return [] diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index feb8272..d138202 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -1,6 +1,5 @@ import logging -from pathlib import Path -from typing import IO, Union, Optional +from typing import Optional from modelscan.scanners.scan import ScanBase, ScanResults from modelscan.tools.utils import _is_zipfile @@ -9,6 +8,7 @@ scan_pickle_bytes, scan_pytorch, ) +from modelscan.model import Model logger = logging.getLogger("modelscan") @@ 
@@ -16,28 +16,18 @@ class PyTorchUnsafeOpScan(ScanBase):
     def scan(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> Optional[ScanResults]:
-        if (
-            not Path(source).suffix
-            in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][
-                "supported_extensions"
-            ]
-        ):
+        if "pytorch" not in model.get_context("formats"):
             return None
 
-        if _is_zipfile(source, data):
+        if _is_zipfile(model.get_source(), model.get_stream()):
             return None
 
-        if data:
-            results = scan_pytorch(data=data, source=source, settings=self._settings)
-
-        else:
-            with open(source, "rb") as file_io:
-                results = scan_pytorch(
-                    data=file_io, source=source, settings=self._settings
-                )
+        results = scan_pytorch(
+            model=model,
+            settings=self._settings,
+        )
 
         return self.label_results(results)
 
@@ -53,25 +43,15 @@ def full_name() -> str:
 class NumpyUnsafeOpScan(ScanBase):
     def scan(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> Optional[ScanResults]:
-        if (
-            not Path(source).suffix
-            in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][
-                "supported_extensions"
-            ]
-        ):
+        if "numpy" not in model.get_context("formats"):
             return None
 
-        if data:
-            results = scan_numpy(data=data, source=source, settings=self._settings)
-
-        else:
-            with open(source, "rb") as file_io:
-                results = scan_numpy(
-                    data=file_io, source=source, settings=self._settings
-                )
+        results = scan_numpy(
+            model=model,
+            settings=self._settings,
+        )
 
         return self.label_results(results)
 
@@ -87,27 +67,15 @@ def full_name() -> str:
 class PickleUnsafeOpScan(ScanBase):
     def scan(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> Optional[ScanResults]:
-        if (
-            not Path(source).suffix
-            in self._settings["scanners"][PickleUnsafeOpScan.full_name()][
-                "supported_extensions"
-            ]
-        ):
+        if "pickle" not in model.get_context("formats"):
             return None
 
-        if data:
-            results = scan_pickle_bytes(
-                data=data, source=source, settings=self._settings
-            )
-
-        else:
-            with open(source, "rb") as file_io:
-                results = scan_pickle_bytes(
-                    data=file_io, source=source, settings=self._settings
-                )
+        results = scan_pickle_bytes(
+            model=model,
+            settings=self._settings,
+        )
 
         return self.label_results(results)
diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py
index 35d9d5b..c78bb73 100644
--- a/modelscan/scanners/saved_model/scan.py
+++ b/modelscan/scanners/saved_model/scan.py
@@ -19,6 +19,7 @@
 from modelscan.error import ModelScanError, ErrorCategories
 from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
 from modelscan.scanners.scan import ScanBase, ScanResults
+from modelscan.model import Model
 
 logger = logging.getLogger("modelscan")
 
@@ -26,13 +27,9 @@
 class SavedModelScan(ScanBase):
     def scan(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> Optional[ScanResults]:
-        if (
-            not Path(source).suffix
-            in self._settings["scanners"][self.full_name()]["supported_extensions"]
-        ):
+        if "tf_saved_model" not in model.get_context("formats"):
             return None
 
         dep_error = self.handle_binary_dependencies()
@@ -49,19 +46,11 @@ def scan(
             [],
         )
 
-        if data:
-            results = self._scan(source, data)
+        results = self._scan(model)
 
-        else:
-            with open(source, "rb") as file_io:
-                results = self._scan(source, data=file_io)
+        return self.label_results(results) if results else None
 
-        if results:
-            return self.label_results(results)
-        else:
-            return None
-
-    def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]:
+    def _scan(self, model: Model) -> Optional[ScanResults]:
         raise NotImplementedError
 
     # This function checks for malicious operators in both Keras and Tensorflow
@@ -69,7 +58,7 @@ def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResul
     def _check_for_unsafe_tf_keras_operator(
         module_name: str,
         raw_operator: List[str],
-        source: Union[str, Path],
+        model: Model,
         unsafe_operators: Dict[str, Any],
     ) -> ScanResults:
         issues: List[Issue] = []
@@ -93,7 +82,7 @@ def _check_for_unsafe_tf_keras_operator(
                     details=OperatorIssueDetails(
                         module=module_name,
                         operator=op,
-                        source=source,
+                        source=str(model.get_source()),
                         severity=severity,
                     ),
                 )
@@ -117,44 +106,39 @@ def full_name() -> str:
 class SavedModelLambdaDetectScan(SavedModelScan):
-    def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]:
-        file_name = str(source).split("/")[-1]
-        if file_name == "keras_metadata.pb":
-            machine_learning_library_name = "Keras"
-            operators_in_model = self._get_keras_pb_operator_names(
-                data=data, source=source
-            )
-            if operators_in_model:
-                if "JSONDecodeError" in operators_in_model:
-                    return ScanResults(
-                        [],
-                        [
-                            ModelScanError(
-                                self.name(),
-                                ErrorCategories.JSON_DECODE,
-                                f"Not a valid JSON data",
-                                str(source),
-                            )
-                        ],
-                        [],
-                    )
+    def _scan(self, model: Model) -> Optional[ScanResults]:
+        file_name = str(model.get_source()).split("/")[-1]
+        if file_name != "keras_metadata.pb":
+            return None
 
-            return SavedModelScan._check_for_unsafe_tf_keras_operator(
-                machine_learning_library_name,
-                operators_in_model,
-                source,
-                self._settings["scanners"][self.full_name()]["unsafe_keras_operators"],
-            )
+        machine_learning_library_name = "Keras"
+        operators_in_model = self._get_keras_pb_operator_names(model)
+        if operators_in_model:
+            if "JSONDecodeError" in operators_in_model:
+                return ScanResults(
+                    [],
+                    [
+                        ModelScanError(
+                            self.name(),
+                            ErrorCategories.JSON_DECODE,
+                            f"Not a valid JSON data",
+                            str(model.get_source()),
+                        )
+                    ],
+                    [],
+                )
 
-        else:
-            return None
+        return SavedModelScan._check_for_unsafe_tf_keras_operator(
+            machine_learning_library_name,
+            operators_in_model,
+            model,
+            self._settings["scanners"][self.full_name()]["unsafe_keras_operators"],
+        )
 
     @staticmethod
-    def _get_keras_pb_operator_names(
-        data: IO[bytes], source: Union[str, Path]
-    ) -> List[str]:
+    def _get_keras_pb_operator_names(model: Model) -> List[str]:
         saved_metadata = SavedMetadata()
-        saved_metadata.ParseFromString(data.read())
+        saved_metadata.ParseFromString(model.get_stream().read())
 
         try:
             lambda_layers = [
@@ -170,7 +154,9 @@ def _get_keras_pb_operator_names(
             return ["Lambda"] * len(lambda_layers)
 
         except json.JSONDecodeError as e:
-            logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
+            logger.error(
+                f"Not a valid JSON data from source: {str(model.get_source())}, error: {e}"
+            )
             return ["JSONDecodeError"]
 
         return []
@@ -181,25 +167,24 @@ def full_name() -> str:
 class SavedModelTensorflowOpScan(SavedModelScan):
-    def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]:
-        file_name = str(source).split("/")[-1]
+    def _scan(self, model: Model) -> Optional[ScanResults]:
+        file_name = str(model.get_source()).split("/")[-1]
         if file_name == "keras_metadata.pb":
             return None
 
-        else:
-            machine_learning_library_name = "Tensorflow"
-            operators_in_model = self._get_tensorflow_operator_names(data=data)
+        machine_learning_library_name = "Tensorflow"
+        operators_in_model = self._get_tensorflow_operator_names(model)
 
         return SavedModelScan._check_for_unsafe_tf_keras_operator(
             machine_learning_library_name,
             operators_in_model,
-            source,
+            model,
             self._settings["scanners"][self.full_name()]["unsafe_tf_operators"],
         )
 
-    def _get_tensorflow_operator_names(self, data: IO[bytes]) -> List[str]:
+    def _get_tensorflow_operator_names(self, model: Model) -> List[str]:
         saved_model = SavedModel()
-        saved_model.ParseFromString(data.read())
+        saved_model.ParseFromString(model.get_stream().read())
 
         model_op_names: Set[str] = set()
 
         # Iterate over every metagraph in case there is more than one
diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py
index 681ef70..89a5981 100644
--- a/modelscan/scanners/scan.py
+++ b/modelscan/scanners/scan.py
@@ -1,10 +1,10 @@
 import abc
-from pathlib import Path
-from typing import List, Union, Optional, IO, Any, Dict
+from typing import List, Optional, Any, Dict
 
 from modelscan.error import ModelScanError
 from modelscan.skip import ModelScanSkipped
 from modelscan.issues import Issue
+from modelscan.model import Model
 
 
 class ScanResults:
@@ -43,8 +43,7 @@ def full_name() -> str:
     @abc.abstractmethod
     def scan(
         self,
-        source: Union[str, Path],
-        data: Optional[IO[bytes]] = None,
+        model: Model,
     ) -> Optional[ScanResults]:
         raise NotImplementedError
 
diff --git a/modelscan/settings.py b/modelscan/settings.py
index 8041ffe..f4bfa9e 100644
--- a/modelscan/settings.py
+++ b/modelscan/settings.py
@@ -15,22 +15,18 @@
     "scanners": {
         "modelscan.scanners.H5LambdaDetectScan": {
             "enabled": True,
-            "supported_extensions": [".h5"],
         },
         "modelscan.scanners.KerasLambdaDetectScan": {
             "enabled": True,
-            "supported_extensions": [".keras"],
         },
         "modelscan.scanners.SavedModelLambdaDetectScan": {
             "enabled": True,
-            "supported_extensions": [".pb"],
             "unsafe_keras_operators": {
                 "Lambda": "MEDIUM",
             },
         },
         "modelscan.scanners.SavedModelTensorflowOpScan": {
             "enabled": True,
-            "supported_extensions": [".pb"],
             "unsafe_tf_operators": {
                 "ReadFile": "HIGH",
                 "WriteFile": "HIGH",
@@ -38,24 +34,34 @@
         },
         "modelscan.scanners.NumpyUnsafeOpScan": {
             "enabled": True,
-            "supported_extensions": [".npy"],
         },
         "modelscan.scanners.PickleUnsafeOpScan": {
             "enabled": True,
-            "supported_extensions": [
-                ".pkl",
-                ".pickle",
-                ".joblib",
-                ".dill",
-                ".dat",
-                ".data",
-            ],
         },
         "modelscan.scanners.PyTorchUnsafeOpScan": {
             "enabled": True,
-            "supported_extensions": [".bin", ".pt", ".pth", ".ckpt"],
         },
     },
+    "middlewares": {
+        "modelscan.middlewares.FormatViaExtensionMiddleware": {
+            "formats": {
+                "tf": [".pb"],
+                "tf_saved_model": [".pb"],
+                "keras_h5": [".h5"],
+                "keras": [".keras"],
+                "numpy": [".npy"],
+                "pytorch": [".bin", ".pt", ".pth", ".ckpt"],
+                "pickle": [
+                    ".pkl",
+                    ".pickle",
+                    ".joblib",
+                    ".dill",
+                    ".dat",
+                    ".data",
+                ],
+            }
+        }
+    },
     "unsafe_globals": {
         "CRITICAL": {
             "__builtin__": [
diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py
index d177288..b6e6078 100644
--- a/modelscan/tools/picklescanner.py
+++ b/modelscan/tools/picklescanner.py
@@ -1,7 +1,5 @@
 import logging
 import pickletools  # nosec
-from dataclasses import dataclass
-from pathlib import Path
 from tarfile import TarError
 from typing import IO, Any, Dict, List, Set, Tuple, Union
 
@@ -11,6 +9,7 @@
 from modelscan.skip import ModelScanSkipped, SkipCategories
 from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
 from modelscan.scanners.scan import ScanResults
+from modelscan.model import Model
 
 logger = logging.getLogger("modelscan")
 
@@ -117,17 +116,15 @@ def _list_globals(
 
 
 def scan_pickle_bytes(
-    data: IO[bytes],
-    source: Union[Path, str],
+    model: Model,
     settings: Dict[str, Any],
     scan_name: str = "pickle",
     multiple_pickles: bool = True,
 ) -> ScanResults:
     """Disassemble a Pickle stream and report issues"""
-
     issues: List[Issue] = []
     try:
-        raw_globals = _list_globals(data, multiple_pickles)
+        raw_globals = _list_globals(model.get_stream(), multiple_pickles)
     except GenOpsError as e:
         return ScanResults(
             issues,
@@ -136,13 +133,13 @@ def scan_pickle_bytes(
                     scan_name,
                     ErrorCategories.PICKLE_GENOPS,
                     f"Parsing error: {e}",
-                    str(source),
+                    str(model.get_source()),
                 )
             ],
             [],
         )
 
-    logger.debug("Global imports in %s: %s", source, raw_globals)
+    logger.debug("Global imports in %s: %s", model.get_source(), raw_globals)
 
     severities = {
         "CRITICAL": IssueSeverity.CRITICAL,
         "HIGH": IssueSeverity.HIGH,
@@ -176,7 +173,7 @@ def scan_pickle_bytes(
             details=OperatorIssueDetails(
                 module=global_module,
                 operator=global_name,
-                source=source,
+                source=model.get_source(),
                 severity=severity,
             ),
         )
@@ -184,18 +181,16 @@ def scan_pickle_bytes(
     return ScanResults(issues, [], [])
 
 
-def scan_numpy(
-    data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any]
-) -> ScanResults:
+def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults:
     scan_name = "numpy"
     # Code to distinguish from NumPy binary files and pickles.
     _ZIP_PREFIX = b"PK\x03\x04"
     _ZIP_SUFFIX = b"PK\x05\x06"  # empty zip files start with this
     N = len(np.lib.format.MAGIC_PREFIX)
-    magic = data.read(N)
+    magic = model.get_stream().read(N)
     # If the file size is less than N, we need to make sure not
     # to seek past the beginning of the file
-    data.seek(-min(N, len(magic)), 1)  # back-up
+    model.get_stream().seek(-min(N, len(magic)), 1)  # back-up
     if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX):
         # .npz file
         return ScanResults(
@@ -206,40 +201,38 @@ def scan_numpy(
                     scan_name,
                     SkipCategories.NOT_IMPLEMENTED,
                     "Scanning of .npz files is not implemented yet",
-                    str(source),
+                    str(model.get_source()),
                 )
             ],
         )
     elif magic == np.lib.format.MAGIC_PREFIX:
         # .npy file
-        version = np.lib.format.read_magic(data)  # type: ignore[no-untyped-call]
+        version = np.lib.format.read_magic(model.get_stream())  # type: ignore[no-untyped-call]
         np.lib.format._check_version(version)  # type: ignore[attr-defined]
-        _, _, dtype = np.lib.format._read_array_header(data, version)  # type: ignore[attr-defined]
+        _, _, dtype = np.lib.format._read_array_header(model.get_stream(), version)  # type: ignore[attr-defined]
 
         if dtype.hasobject:
-            return scan_pickle_bytes(data, source, settings, scan_name)
+            return scan_pickle_bytes(model, settings, scan_name)
         else:
             return ScanResults([], [], [])
     else:
-        return scan_pickle_bytes(data, source, settings, scan_name)
+        return scan_pickle_bytes(model, settings, scan_name)
 
 
-def scan_pytorch(
-    data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any]
-) -> ScanResults:
+def scan_pytorch(model: Model, settings: Dict[str, Any]) -> ScanResults:
     scan_name = "pytorch"
-    should_read_directly = _should_read_directly(data)
-    if should_read_directly and data.tell() == 0:
+    should_read_directly = _should_read_directly(model.get_stream())
+    if should_read_directly and model.get_stream().tell() == 0:
         # try loading from tar
         try:
             # TODO: implement loading from tar
             raise TarError()
         except TarError:
             # file does not contain a tar
-            data.seek(0)
+            model.get_stream().seek(0)
 
-    magic = get_magic_number(data)
+    magic = get_magic_number(model.get_stream())
     if magic != MAGIC_NUMBER:
         return ScanResults(
             [],
@@ -249,8 +242,9 @@ def scan_pytorch(
                     scan_name,
                     SkipCategories.MAGIC_NUMBER,
                     f"Invalid magic number",
-                    str(source),
+                    str(model.get_source()),
                 )
             ],
         )
-    return scan_pickle_bytes(data, source, settings, scan_name, multiple_pickles=False)
+
+    return scan_pickle_bytes(model, settings, scan_name, multiple_pickles=False)
diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py
index cde5c63..cc9963d 100644
--- a/tests/test_modelscan.py
+++ b/tests/test_modelscan.py
@@ -34,11 +34,11 @@
 )
 from modelscan.tools.picklescanner import (
     scan_pickle_bytes,
-    scan_numpy,
 )
 
 from modelscan.skip import SkipCategories
 from modelscan.settings import DEFAULT_SETTINGS
+from modelscan.model import Model
 
 settings: Dict[str, Any] = DEFAULT_SETTINGS
 
@@ -431,12 +431,8 @@ def test_scan_pickle_bytes() -> None:
         )
     ]
 
-    assert (
-        scan_pickle_bytes(
-            io.BytesIO(pickle.dumps(Malicious1())), "file.pkl", settings
-        ).issues
-        == expected
-    )
+    model = Model("file.pkl", io.BytesIO(pickle.dumps(Malicious1())))
+    assert scan_pickle_bytes(model, settings).issues == expected
 
 
 def test_scan_zip(zip_file_path: Any) -> None:
@@ -456,7 +452,10 @@ def test_scan_zip(zip_file_path: Any) -> None:
     ms = ModelScan()
     results = ms.scan(f"{zip_file_path}/test.zip")
     assert results["summary"]["scanned"]["scanned_files"] == [f"test.zip:data.pkl"]
-    assert results["summary"]["skipped"]["skipped_files"] == []
+    assert [
+        skipped_file["source"]
+        for skipped_file in results["summary"]["skipped"]["skipped_files"]
+    ] == ["test.zip"]
 
     assert ms.issues.all_issues == expected
 
@@ -478,14 +477,16 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
         f"safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl"
     ]
 
-    assert [
-        skipped_file["source"]
-        for skipped_file in results["summary"]["skipped"]["skipped_files"]
-    ] == [
+    assert set(
+        [
+            skipped_file["source"]
+            for skipped_file in results["summary"]["skipped"]["skipped_files"]
+        ]
+    ) == {
         "safe_zip_pytorch.pt:safe_zip_pytorch/byteorder",
         "safe_zip_pytorch.pt:safe_zip_pytorch/version",
         "safe_zip_pytorch.pt:safe_zip_pytorch/.data/serialization_id",
-    ]
+    }
     assert ms.issues.all_issues == []
     assert results["errors"] == []
 
@@ -514,14 +515,16 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
     assert results["summary"]["scanned"]["scanned_files"] == [
         f"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl",
     ]
-    assert [
-        skipped_file["source"]
-        for skipped_file in results["summary"]["skipped"]["skipped_files"]
-    ] == [
+    assert set(
+        [
+            skipped_file["source"]
+            for skipped_file in results["summary"]["skipped"]["skipped_files"]
+        ]
+    ) == {
         "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/byteorder",
         "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/version",
         "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/.data/serialization_id",
-    ]
+    }
     assert ms.issues.all_issues == expected
     assert results["errors"] == []
 
@@ -1239,7 +1242,10 @@ def test_scan_directory_path(file_path: str) -> None:
         f"benign0_v3.dill",
         f"benign0_v4.dill",
     }
-    assert results["summary"]["skipped"]["skipped_files"] == []
+    assert [
+        skipped_file["source"]
+        for skipped_file in results["summary"]["skipped"]["skipped_files"]
+    ] == ["malicious1.zip"]
     assert results["errors"] == []
 
 
@@ -1287,19 +1293,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
         f"safe{file_extension}"
     ]
 
-    if file_extension == ".keras":
-        assert set(
-            [
-                skipped_file["source"]
-                for skipped_file in results["summary"]["skipped"]["skipped_files"]
-            ]
-        ) == {
-            f"safe{file_extension}:metadata.json",
-            f"safe{file_extension}:config.json",
f"safe{file_extension}:model.weights.h5", - } - else: - assert results["summary"]["skipped"]["skipped_files"] == [] + assert results["summary"]["skipped"]["skipped_files"] == [] assert results["errors"] == []