From 4291fbcd50fdf177bcf4fac138cd71c8d3fa8254 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Wed, 10 Jan 2024 17:31:32 -0800 Subject: [PATCH] get scanners from settings, improve settings schema --- modelscan/modelscan.py | 61 ++++++++++--------- modelscan/scanners/h5/scan.py | 2 +- modelscan/scanners/keras/scan.py | 2 +- modelscan/scanners/pickle/scan.py | 10 +++- modelscan/scanners/saved_model/scan.py | 10 ++-- modelscan/settings.py | 83 ++++++++++++-------------- tests/test_modelscan.py | 2 +- 7 files changed, 85 insertions(+), 85 deletions(-) diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index f459119..e223565 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -2,7 +2,7 @@ import zipfile import importlib -from modelscan.settings import DEFAULT_SCANNERS, DEFAULT_SETTINGS +from modelscan.settings import DEFAULT_SETTINGS from pathlib import Path from typing import List, Union, Optional, IO, Dict, Tuple, Any @@ -20,7 +20,6 @@ class ModelScan: def __init__( self, - scanners_to_load: List[str] = DEFAULT_SCANNERS, settings: Dict[str, Any] = DEFAULT_SETTINGS, ) -> None: # Output @@ -33,38 +32,42 @@ def __init__( # Scanners self._scanners_to_run: List[ScanBase] = [] self._settings: Dict[str, Any] = settings - self._load_scanners(scanners_to_load) + self._load_scanners() - def _load_scanners(self, scanners_to_load: List[str]) -> None: + def _load_scanners(self) -> None: scanner_classes: Dict[str, ScanBase] = {} - for scanner_path in scanners_to_load: - try: - (modulename, classname) = scanner_path.rsplit(".", 1) - imported_module = importlib.import_module( - name=modulename, package=classname - ) - scanner_class = getattr(imported_module, classname) - scanner_classes[scanner_path] = scanner_class - except Exception as e: - logger.error(f"Error importing scanner {scanner_path}") - self._errors.append( - ModelScanError( - scanner_path, f"Error importing scanner {scanner_path}: {e}" + + for scanner_path, scanner_settings in self._settings["scanners"].items(): + if ( + "enabled" in scanner_settings.keys() + and self._settings["scanners"][scanner_path]["enabled"] + ): + try: + (modulename, classname) = scanner_path.rsplit(".", 1) + imported_module = importlib.import_module( + name=modulename, package=classname + ) + scanner_class = getattr(imported_module, classname) + scanner_classes[scanner_path] = scanner_class + except Exception as e: + logger.error(f"Error importing scanner {scanner_path}") + self._errors.append( + ModelScanError( + scanner_path, f"Error importing scanner {scanner_path}: {e}" + ) ) - ) scanners_to_run: List[ScanBase] = [] for scanner_class, scanner in scanner_classes.items(): - is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"] - if is_enabled: - dep_error = scanner.handle_binary_dependencies() - if dep_error: - logger.info( - f"Skipping {scanner.full_name()} as it is missing dependencies" - ) - self._errors.append(dep_error) - else: - scanners_to_run.append(scanner) + dep_error = scanner.handle_binary_dependencies() + if dep_error: + logger.info( + f"Skipping {scanner.full_name()} as it is missing dependencies" + ) + self._errors.append(dep_error) + else: + scanners_to_run.append(scanner) + self._scanners_to_run = scanners_to_run def scan( @@ -114,7 +117,7 @@ def _scan_source( ) -> bool: scanned = False for scan_class in self._scanners_to_run: - scanner = scan_class(self._settings["scanners"]) # type: ignore[operator] + scanner = scan_class(self._settings) # type: ignore[operator] scan_results = scanner.scan( source=source, data=data, diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index ceac067..8b0f6dc 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -25,7 +25,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[H5Scan.full_name()]["supported_extensions"] + in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"] ): return None diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 8c13002..f6f28b0 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -21,7 +21,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[KerasScan.full_name()]["supported_extensions"] + in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"] ): return None diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index 7ac1efd..033667d 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -20,7 +20,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[PyTorchScan.full_name()]["supported_extensions"] + in self._settings["scanners"][PyTorchScan.full_name()][ + "supported_extensions" + ] ): return None @@ -52,7 +54,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[NumpyScan.full_name()]["supported_extensions"] + in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"] ): return None @@ -81,7 +83,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[PickleScan.full_name()]["supported_extensions"] + in self._settings["scanners"][PickleScan.full_name()][ + "supported_extensions" + ] ): return None diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index 8fa3ee4..b713b62 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -31,7 +31,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings[SavedModelScan.full_name()]["supported_extensions"] + in self._settings["scanners"][SavedModelScan.full_name()][ + "supported_extensions" + ] ): return None @@ -111,9 +113,9 @@ def _check_for_unsafe_tf_keras_operator( source: Union[str, Path], settings: Dict[str, Any], ) -> ScanResults: - unsafe_operators: Dict[str, Any] = settings[SavedModelScan.full_name()][ - "unsafe_tf_keras_operators" - ] + unsafe_operators: Dict[str, Any] = settings["scanners"][ + SavedModelScan.full_name() + ]["unsafe_tf_keras_operators"] issues: List[Issue] = [] all_operators = tensorflow.raw_ops.__dict__.keys() diff --git a/modelscan/settings.py b/modelscan/settings.py index dfb2b6d..fb3ddfc 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -3,15 +3,6 @@ from typing import Any -DEFAULT_SCANNERS = [ - "modelscan.scanners.H5Scan", - "modelscan.scanners.KerasScan", - "modelscan.scanners.SavedModelScan", - "modelscan.scanners.NumpyScan", - "modelscan.scanners.PickleScan", - "modelscan.scanners.PyTorchScan", -] - DEFAULT_SETTINGS = { "supported_zip_extensions": [".zip", ".npz"], "scanners": { @@ -51,44 +42,44 @@ "enabled": True, "supported_extensions": [".bin", ".pt", ".pth", ".ckpt"], }, - "unsafe_globals": { - "CRITICAL": { - "__builtin__": [ - "eval", - "compile", - "getattr", - "apply", - "exec", - "open", - "breakpoint", - ], # Pickle versions 0, 1, 2 have those function under '__builtin__' - "builtins": [ - "eval", - "compile", - "getattr", - "apply", - "exec", - "open", - "breakpoint", - ], # Pickle versions 3, 4 have those function under 'builtins' - "runpy": "*", - "os": "*", - "nt": "*", # Alias for 'os' on Windows. Includes os.system() - "posix": "*", # Alias for 'os' on Linux. Includes os.system() - "socket": "*", - "subprocess": "*", - "sys": "*", - "operator": "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") - }, - "HIGH": { - "webbrowser": "*", # Includes webbrowser.open() - "httplib": "*", # Includes http.client.HTTPSConnection() - "requests.api": "*", - "aiohttp.client": "*", - }, - "MEDIUM": {}, - "LOW": {}, + }, + "unsafe_globals": { + "CRITICAL": { + "__builtin__": [ + "eval", + "compile", + "getattr", + "apply", + "exec", + "open", + "breakpoint", + ], # Pickle versions 0, 1, 2 have those function under '__builtin__' + "builtins": [ + "eval", + "compile", + "getattr", + "apply", + "exec", + "open", + "breakpoint", + ], # Pickle versions 3, 4 have those function under 'builtins' + "runpy": "*", + "os": "*", + "nt": "*", # Alias for 'os' on Windows. Includes os.system() + "posix": "*", # Alias for 'os' on Linux. Includes os.system() + "socket": "*", + "subprocess": "*", + "sys": "*", + "operator": "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") + }, + "HIGH": { + "webbrowser": "*", # Includes webbrowser.open() + "httplib": "*", # Includes http.client.HTTPSConnection() + "requests.api": "*", + "aiohttp.client": "*", }, + "MEDIUM": {}, + "LOW": {}, }, } diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 209015d..4ef08ce 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -32,7 +32,7 @@ ) from modelscan.settings import DEFAULT_SETTINGS -settings: Dict[str, Any] = DEFAULT_SETTINGS["scanners"] # type: ignore[assignment] +settings: Dict[str, Any] = DEFAULT_SETTINGS class Malicious1: