diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 2ab397e..00295e0 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -2,7 +2,7 @@ import zipfile import importlib -from modelscan.settings import DEFAULT_SCANNERS, DEFAULT_SETTINGS +from modelscan.settings import DEFAULT_SETTINGS from pathlib import Path from typing import List, Union, Optional, IO, Dict, Tuple, Any @@ -20,7 +20,6 @@ class ModelScan: def __init__( self, - scanners_to_load: List[str] = DEFAULT_SCANNERS, settings: Dict[str, Any] = DEFAULT_SETTINGS, ) -> None: # Output @@ -34,32 +33,30 @@ def __init__( # Scanners self._scanners_to_run: List[ScanBase] = [] self._settings: Dict[str, Any] = settings - self._load_scanners(scanners_to_load) - - def _load_scanners(self, scanners_to_load: List[str]) -> None: - scanner_classes: Dict[str, ScanBase] = {} - for scanner_path in scanners_to_load: - try: - (modulename, classname) = scanner_path.rsplit(".", 1) - imported_module = importlib.import_module( - name=modulename, package=classname - ) - scanner_class = getattr(imported_module, classname) - scanner_classes[scanner_path] = scanner_class - except Exception as e: - logger.error(f"Error importing scanner {scanner_path}") - self._init_errors.append( - ModelScanError( - scanner_path, f"Error importing scanner {scanner_path}: {e}" + self._load_scanners() + + def _load_scanners(self) -> None: + for scanner_path, scanner_settings in self._settings["scanners"].items(): + if ( + "enabled" in scanner_settings.keys() + and self._settings["scanners"][scanner_path]["enabled"] + ): + try: + (modulename, classname) = scanner_path.rsplit(".", 1) + imported_module = importlib.import_module( + name=modulename, package=classname ) - ) - scanners_to_run: List[ScanBase] = [] - for scanner_class, scanner in scanner_classes.items(): - is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"] - if is_enabled: - scanners_to_run.append(scanner) - self._scanners_to_run = scanners_to_run + scanner_class: ScanBase = getattr(imported_module, classname) + self._scanners_to_run.append(scanner_class) + + except Exception as e: + logger.error(f"Error importing scanner {scanner_path}") + self._init_errors.append( + ModelScanError( + scanner_path, f"Error importing scanner {scanner_path}: {e}" + ) + ) def scan( self, @@ -109,7 +106,7 @@ def _scan_source( ) -> bool: scanned = False for scan_class in self._scanners_to_run: - scanner = scan_class(self._settings["scanners"]) # type: ignore[operator] + scanner = scan_class(self._settings) # type: ignore[operator] scan_results = scanner.scan( source=source, data=data, diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index e992966..8855ed7 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -25,7 +25,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[H5Scan.full_name()]["supported_extensions"] + in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"] ): return None diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index cf8cbe7..f2ece66 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -21,7 +21,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[KerasScan.full_name()]["supported_extensions"] + in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"] ): return None diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index 7ac1efd..033667d 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -20,7 +20,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[PyTorchScan.full_name()]["supported_extensions"] + in self._settings["scanners"][PyTorchScan.full_name()][ + "supported_extensions" + ] ): return None @@ -52,7 +54,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[NumpyScan.full_name()]["supported_extensions"] + in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"] ): return None @@ -81,7 +83,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[PickleScan.full_name()]["supported_extensions"] + in self._settings["scanners"][PickleScan.full_name()][ + "supported_extensions" + ] ): return None diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index e928ff8..0b2dd4c 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -31,7 +31,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[SavedModelScan.full_name()]["supported_extensions"] + in self._settings["scanners"][SavedModelScan.full_name()][ + "supported_extensions" + ] ): return None @@ -115,9 +117,9 @@ def _check_for_unsafe_tf_keras_operator( source: Union[str, Path], settings: Dict[str, Any], ) -> ScanResults: - unsafe_operators: Dict[str, Any] = settings[SavedModelScan.full_name()][ - "unsafe_tf_keras_operators" - ] + unsafe_operators: Dict[str, Any] = settings["scanners"][ + SavedModelScan.full_name() + ]["unsafe_tf_keras_operators"] issues: List[Issue] = [] all_operators = tensorflow.raw_ops.__dict__.keys() diff --git a/modelscan/settings.py b/modelscan/settings.py index 8ced0a1..00b008d 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -1,15 +1,8 @@ import tomlkit -DEFAULT_SCANNERS = [ - "modelscan.scanners.H5Scan", - "modelscan.scanners.KerasScan", - "modelscan.scanners.SavedModelScan", - "modelscan.scanners.NumpyScan", - "modelscan.scanners.PickleScan", - "modelscan.scanners.PyTorchScan", -] from typing import Any + DEFAULT_SETTINGS = { "supported_zip_extensions": [".zip", ".npz"], "scanners": { @@ -49,50 +42,50 @@ "enabled": True, "supported_extensions": [".bin", ".pt", ".pth", ".ckpt"], }, - "unsafe_globals": { - "CRITICAL": { - "__builtin__": [ - "eval", - "compile", - "getattr", - "apply", - "exec", - "open", - "breakpoint", - "__import__", - ], # Pickle versions 0, 1, 2 have those function under '__builtin__' - "builtins": [ - "eval", - "compile", - "getattr", - "apply", - "exec", - "open", - "breakpoint", - "__import__", - ], # Pickle versions 3, 4 have those function under 'builtins' - "runpy": "*", - "os": "*", - "nt": "*", # Alias for 'os' on Windows. Includes os.system() - "posix": "*", # Alias for 'os' on Linux. Includes os.system() - "socket": "*", - "subprocess": "*", - "sys": "*", - "operator": [ - "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") - ], - "pty": "*", - "pickle": "*", - }, - "HIGH": { - "webbrowser": "*", # Includes webbrowser.open() - "httplib": "*", # Includes http.client.HTTPSConnection() - "requests.api": "*", - "aiohttp.client": "*", - }, - "MEDIUM": {}, - "LOW": {}, + }, + "unsafe_globals": { + "CRITICAL": { + "__builtin__": [ + "eval", + "compile", + "getattr", + "apply", + "exec", + "open", + "breakpoint", + "__import__", + ], # Pickle versions 0, 1, 2 have those function under '__builtin__' + "builtins": [ + "eval", + "compile", + "getattr", + "apply", + "exec", + "open", + "breakpoint", + "__import__", + ], # Pickle versions 3, 4 have those function under 'builtins' + "runpy": "*", + "os": "*", + "nt": "*", # Alias for 'os' on Windows. Includes os.system() + "posix": "*", # Alias for 'os' on Linux. Includes os.system() + "socket": "*", + "subprocess": "*", + "sys": "*", + "operator": [ + "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") + ], + "pty": "*", + "pickle": "*", + }, + "HIGH": { + "webbrowser": "*", # Includes webbrowser.open() + "httplib": "*", # Includes http.client.HTTPSConnection() + "requests.api": "*", + "aiohttp.client": "*", }, + "MEDIUM": {}, + "LOW": {}, "reporting_module": { "module": "modelscan.reports.ConsoleReport", "settings": {}, diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 968b28a..99f2fd4 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -32,7 +32,7 @@ ) from modelscan.settings import DEFAULT_SETTINGS -settings: Dict[str, Any] = DEFAULT_SETTINGS["scanners"] # type: ignore[assignment] +settings: Dict[str, Any] = DEFAULT_SETTINGS class Malicious1: