diff --git a/modelscan/issues.py b/modelscan/issues.py index bd2e911..ba88930 100644 --- a/modelscan/issues.py +++ b/modelscan/issues.py @@ -142,3 +142,6 @@ def output_json(self) -> Dict[str, str]: "source": f"{str(self.source)}", "scanner": f"{self.scanner}", } + + def __repr__(self) -> str: + return f"" diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index e223565..4a5681b 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -25,6 +25,7 @@ def __init__( # Output self._issues = Issues() self._errors: List[Error] = [] + self._init_errors: List[Error] = [] self._skipped: List[str] = [] self._scanned: List[str] = [] self._input_path: str = "" @@ -51,7 +52,7 @@ def _load_scanners(self) -> None: scanner_classes[scanner_path] = scanner_class except Exception as e: logger.error(f"Error importing scanner {scanner_path}") - self._errors.append( + self._init_errors.append( ModelScanError( scanner_path, f"Error importing scanner {scanner_path}: {e}" ) @@ -59,15 +60,9 @@ def _load_scanners(self) -> None: scanners_to_run: List[ScanBase] = [] for scanner_class, scanner in scanner_classes.items(): - dep_error = scanner.handle_binary_dependencies() - if dep_error: - logger.info( - f"Skipping {scanner.full_name()} as it is missing dependencies" - ) - self._errors.append(dep_error) - else: + is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"] + if is_enabled: scanners_to_run.append(scanner) - self._scanners_to_run = scanners_to_run def scan( @@ -76,6 +71,7 @@ def scan( ) -> Dict[str, Any]: self._issues = Issues() self._errors = [] + self._errors.extend(self._init_errors) self._skipped = [] self._scanned = [] self._input_path = str(path) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index 8b0f6dc..8855ed7 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -29,6 +29,10 @@ def scan( ): return None + dep_error = self.handle_binary_dependencies() + if dep_error: + return ScanResults([], [dep_error]) + if data: logger.warning( "H5 scanner got data bytes. It only support direct file scanning." diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index f6f28b0..f2ece66 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -25,6 +25,10 @@ def scan( ): return None + dep_error = self.handle_binary_dependencies() + if dep_error: + return ScanResults([], [dep_error]) + try: with zipfile.ZipFile(data or source, "r") as zip: file_names = zip.namelist() diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index b713b62..0b2dd4c 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -37,6 +37,10 @@ def scan( ): return None + dep_error = self.handle_binary_dependencies() + if dep_error: + return ScanResults([], [dep_error]) + if data: results = self._scan(source, data) diff --git a/modelscan/settings.py b/modelscan/settings.py index fb3ddfc..9f92cce 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -53,6 +53,7 @@ "exec", "open", "breakpoint", + "__import__", ], # Pickle versions 0, 1, 2 have those function under '__builtin__' "builtins": [ "eval", @@ -62,6 +63,7 @@ "exec", "open", "breakpoint", + "__import__", ], # Pickle versions 3, 4 have those function under 'builtins' "runpy": "*", "os": "*", @@ -70,7 +72,11 @@ "socket": "*", "subprocess": "*", "sys": "*", - "operator": "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") + "operator": [ + "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned") + ], + "pty": "*", + "pickle": "*", }, "HIGH": { "webbrowser": "*", # Includes webbrowser.open() diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index 0dbe6d2..98c7f41 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -134,45 +134,41 @@ def scan_pickle_bytes( ) logger.debug("Global imports in %s: %s", source, raw_globals) + severities = { + "CRITICAL": IssueSeverity.CRITICAL, + "HIGH": IssueSeverity.HIGH, + "MEDIUM": IssueSeverity.MEDIUM, + "LOW": IssueSeverity.LOW, + } for rg in raw_globals: global_module, global_name, severity = rg[0], rg[1], None - unsafe_critical_filter = settings["unsafe_globals"]["CRITICAL"].get( - global_module - ) - unsafe_high_filter = settings["unsafe_globals"]["HIGH"].get(global_module) - unsafe_medium_filter = settings["unsafe_globals"]["MEDIUM"].get(global_module) - unsafe_low_filter = settings["unsafe_globals"]["LOW"].get(global_module) - if unsafe_critical_filter is not None and ( - unsafe_critical_filter == "*" or global_name in unsafe_critical_filter - ): - severity = IssueSeverity.CRITICAL - - elif unsafe_high_filter is not None and ( - unsafe_high_filter == "*" or global_name in unsafe_high_filter - ): - severity = IssueSeverity.HIGH - elif unsafe_medium_filter is not None and ( - unsafe_medium_filter == "*" or global_name in unsafe_medium_filter - ): - severity = IssueSeverity.MEDIUM - elif unsafe_low_filter is not None and ( - unsafe_low_filter == "*" or global_name in unsafe_low_filter - ): - severity = IssueSeverity.LOW - elif "unknown" in global_module or "unknown" in global_name: - severity = IssueSeverity.MEDIUM - else: - continue - issues.append( - Issue( - code=IssueCode.UNSAFE_OPERATOR, - severity=severity, - details=OperatorIssueDetails( - module=global_module, operator=global_name, source=source - ), + for severity_name in severities: + if global_module not in settings["unsafe_globals"][severity_name]: + continue + filter = settings["unsafe_globals"][severity_name][global_module] + if filter == "*": + severity = severities[severity_name] + break + for filter_value in filter: + if filter_value in global_name: + severity = severities[severity_name] + break + else: + continue + break + if "unknown" in global_module or "unknown" in global_name: + severity = IssueSeverity.CRITICAL # we must assume it is RCE + if severity is not None: + issues.append( + Issue( + code=IssueCode.UNSAFE_OPERATOR, + severity=severity, + details=OperatorIssueDetails( + module=global_module, operator=global_name, source=source + ), + ) ) - ) return ScanResults(issues, []) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 4ef08ce..99f2fd4 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -80,6 +80,52 @@ def __reduce__(self) -> Any: return sys.exit, (0,) +def malicious12_gen() -> bytes: + p = pickle.PROTO + b"\x05" + + # stack = [pickle.loads] + p += pickle.GLOBAL + b"pickle\nloads\n" + + # stack = [pickle.loads, p2] + p2 = ( + pickle.PROTO + + b"\x05" + + pickle.GLOBAL + + b"os\nsystem\n" + + pickle.UNICODE + + b"echo pwned!!!\n" + + pickle.TUPLE1 + + pickle.REDUCE + + pickle.STOP + ) + p += pickle.BINBYTES + len(p2).to_bytes(4, "little") + p2 + + # stack = [pickle.loads, (p2,)] + p += pickle.TUPLE1 + + # stack = [pickle.loads(p2)] + p += pickle.REDUCE + + # return None + p += pickle.POP + p += pickle.NONE + + p += pickle.STOP + return p + + +def malicious13_gen() -> bytes: + p = pickle.PROTO + b"\x05" + + p += pickle.GLOBAL + b"builtins\neval.__call__\n" + p += pickle.UNICODE + b'__import__("os").system("echo pwned!!!")\n' + p += pickle.TUPLE1 + p += pickle.REDUCE + + p += pickle.STOP + return p + + def initialize_pickle_file(path: str, obj: Any, version: int) -> None: if not os.path.exists(path): with open(path, "wb") as file: @@ -209,6 +255,10 @@ def file_path(tmp_path_factory: Any) -> Any: ) initialize_data_file(f"{tmp}/data/malicious10.pkl", malicious10_pickle_bytes) + initialize_data_file(f"{tmp}/data/malicious12.pkl", malicious12_gen()) + + initialize_data_file(f"{tmp}/data/malicious13.pkl", malicious13_gen()) + return tmp @@ -643,6 +693,30 @@ def test_scan_pickle_operators(file_path: Any) -> None: malicious11 = ModelScan() malicious11.scan(Path(f"{file_path}/data/malicious11.pkl")) assert malicious11.issues.all_issues == expected_malicious11 + expected_malicious12 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "pickle", "loads", f"{file_path}/data/malicious12.pkl" + ), + ) + ] + malicious12 = ModelScan() + malicious12.scan(Path(f"{file_path}/data/malicious12.pkl")) + assert malicious12.issues.all_issues == expected_malicious12 + expected_malicious13 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval.__call__", f"{file_path}/data/malicious13.pkl" + ), + ) + ] + malicious13 = ModelScan() + malicious13.scan(Path(f"{file_path}/data/malicious13.pkl")) + assert malicious13.issues.all_issues == expected_malicious13 def test_scan_directory_path(file_path: str) -> None: @@ -807,6 +881,20 @@ def test_scan_directory_path(file_path: str) -> None: IssueSeverity.CRITICAL, OperatorIssueDetails("os", "system", f"{file_path}/data/malicious11.pkl"), ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "pickle", "loads", f"{file_path}/data/malicious12.pkl" + ), + ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "builtins", "eval.__call__", f"{file_path}/data/malicious13.pkl" + ), + ), } ms = ModelScan() p = Path(f"{file_path}/data/")