diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index 2a96c31..0dbe6d2 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -52,7 +52,7 @@ def _list_globals( ) -> Set[Tuple[str, str]]: globals: Set[Any] = set() - memo: Dict[int, str] = {} + memo: Dict[Union[int, str], str] = {} # Scan the data for pickle buffers, stopping when parsing fails or stops making progress last_byte = b"dummy" while last_byte != b"": @@ -72,10 +72,11 @@ def _list_globals( op_name = op[0].name op_value: str = op[1] - if op_name in ["MEMOIZE", "PUT", "BINPUT", "LONG_BINPUT"] and n > 0: + if op_name == "MEMOIZE" and n > 0: memo[len(memo)] = ops[n - 1][1] - - if op_name in ["GLOBAL", "INST"]: + elif op_name in ["PUT", "BINPUT", "LONG_BINPUT"] and n > 0: + memo[op_value] = ops[n - 1][1] + elif op_name in ("GLOBAL", "INST"): globals.add(tuple(op_value.split(" ", 1))) elif op_name == "STACK_GLOBAL": values: List[str] = [] diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 3464dd4..209015d 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -168,6 +168,36 @@ def file_path(tmp_path_factory: Any) -> Any: initialize_pickle_file(f"{tmp}/data/malicious8.pkl", Malicious7(), 4) initialize_pickle_file(f"{tmp}/data/malicious9.pkl", Malicious8(), 4) + # Malicious Pickle from Capture-the-Flag challenge 'Misc/Safe Pickle' at https://imaginaryctf.org/Challenges + # GitHub Issue: https://github.com/mmaitre314/picklescan/issues/22 + initialize_data_file( + f"{tmp}/data/malicious11.pkl", + b"".join( + [ + pickle.UNICODE + b"os\n", + pickle.PUT + b"2\n", + pickle.POP, + pickle.UNICODE + b"system\n", + pickle.PUT + b"3\n", + pickle.POP, + pickle.UNICODE + b"torch\n", + pickle.PUT + b"0\n", + pickle.POP, + pickle.UNICODE + b"LongStorage\n", + pickle.PUT + b"1\n", + pickle.POP, + pickle.GET + b"2\n", + pickle.GET + b"3\n", + pickle.STACK_GLOBAL, + pickle.MARK, + pickle.UNICODE + b"cat flag.txt\n", + pickle.TUPLE, + pickle.REDUCE, + pickle.STOP, + ] + ), + ) + initialize_zip_file( f"{tmp}/data/malicious1.zip", "data.pkl", @@ -603,6 +633,17 @@ def test_scan_pickle_operators(file_path: Any) -> None: malicious10.scan(Path(f"{file_path}/data/malicious10.pkl")) assert malicious10.issues.all_issues == expected_malicious10 + expected_malicious11 = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails("os", "system", f"{file_path}/data/malicious11.pkl"), + ) + ] + malicious11 = ModelScan() + malicious11.scan(Path(f"{file_path}/data/malicious11.pkl")) + assert malicious11.issues.all_issues == expected_malicious11 + def test_scan_directory_path(file_path: str) -> None: expected = { @@ -761,6 +802,11 @@ def test_scan_directory_path(file_path: str) -> None: "__builtin__", "exec", f"{file_path}/data/malicious10.pkl" ), ), + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails("os", "system", f"{file_path}/data/malicious11.pkl"), + ), } ms = ModelScan() p = Path(f"{file_path}/data/")