diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index 033667d..1f2fc6c 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -3,6 +3,7 @@ from typing import IO, Union, Optional from modelscan.scanners.scan import ScanBase, ScanResults +from modelscan.tools.utils import _is_zipfile from modelscan.tools.picklescanner import ( scan_numpy, scan_pickle_bytes, @@ -26,6 +27,9 @@ def scan( ): return None + if _is_zipfile(source): + return None + if data: results = scan_pytorch(data=data, source=source, settings=self._settings) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 99f2fd4..18d2418 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -12,10 +12,15 @@ import socket import subprocess import sys +import torch import tensorflow as tf from tensorflow import keras from typing import Any, List, Set, Dict -from test_utils import generate_dill_unsafe_file, generate_unsafe_pickle_file +from test_utils import ( + generate_dill_unsafe_file, + generate_unsafe_pickle_file, + PyTorchTestModel, +) import zipfile from modelscan.modelscan import ModelScan @@ -166,6 +171,27 @@ def pytorch_file_path(tmp_path_factory: Any) -> Any: tmp = tmp_path_factory.mktemp("pytorch") # Fake PyTorch file (PNG file format) simulating https://huggingface.co/RectalWorm/loras_new/blob/main/Owl_Mage_no_background.pt initialize_data_file(f"{tmp}/bad_pytorch.pt", b"\211PNG\r\n\032\n") + + # Safe PyTorch files in old and new (zip) formats + model = PyTorchTestModel() + torch.save( + model.state_dict(), + f=f"{tmp}/safe_zip_pytorch.pt", + _use_new_zipfile_serialization=True, + ) + torch.save( + model.state_dict(), + f=f"{tmp}/safe_old_format_pytorch.pt", + _use_new_zipfile_serialization=False, + ) + + # Unsafe PyTorch files in new (zip) format + model.generate_unsafe_pytorch_file( + unsafe_file_path=f"{tmp}/unsafe_zip_pytorch.pt", + model_path=f"{tmp}/safe_zip_pytorch.pt", + zipfile=True, + ) + return tmp @@ -382,10 +408,32 @@ def test_scan_zip(zip_file_path: Any) -> None: def test_scan_pytorch(pytorch_file_path: Any) -> None: - bad_pytorch = ModelScan() - bad_pytorch.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt")) - assert bad_pytorch.issues.all_issues == [] - assert [error.scan_name for error in bad_pytorch.errors] == ["pytorch"] # type: ignore[attr-defined] + ms = ModelScan() + ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt")) + assert ms.issues.all_issues == [] + assert [error.scan_name for error in ms.errors] == ["pytorch"] # type: ignore[attr-defined] + + ms.scan(Path(f"{pytorch_file_path}/safe_zip_pytorch.pt")) + assert ms.issues.all_issues == [] + assert ms.errors == [] + + ms.scan(Path(f"{pytorch_file_path}/safe_old_format_pytorch.pt")) + assert ms.issues.all_issues == [] + assert ms.errors == [] + + unsafe_zip_path = f"{pytorch_file_path}/unsafe_zip_pytorch.pt" + expected = [ + Issue( + IssueCode.UNSAFE_OPERATOR, + IssueSeverity.CRITICAL, + OperatorIssueDetails( + "posix", "system", f"{unsafe_zip_path}:unsafe_zip_pytorch/data.pkl" + ), + ), + ] + ms.scan(unsafe_zip_path) + assert ms.errors == [] + assert ms.issues.all_issues == expected def test_scan_numpy(numpy_file_path: Any) -> None: diff --git a/tests/test_utils.py b/tests/test_utils.py index 34b3eeb..8a836d4 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -4,6 +4,8 @@ import struct from typing import Any, Tuple import os +import torch +import torch.nn as nn class PickleInject: @@ -234,3 +236,36 @@ def generate_dill_unsafe_file( mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload]) mypickler.dump(safe_model) file_for_unsafe_model.close() + + +class PyTorchTestModel(nn.Module): # type: ignore[misc] + def __init__(self) -> None: + super().__init__() + self.flatten = nn.Flatten() + self.linear_relu_stack = nn.Sequential( + nn.Linear(28 * 28, 512), + nn.ReLU(), + nn.Linear(512, 512), + nn.ReLU(), + nn.Linear(512, 10), + ) + + def forward(self, x: Any) -> Any: + x = self.flatten(x) + logits = self.linear_relu_stack(x) + return logits + + def generate_unsafe_pytorch_file( + self, unsafe_file_path: str, model_path: str, zipfile: bool = True + ) -> None: + command = "system" + malicious_code = """cat ~/.aws/secrets + """ + + payload = get_pickle_payload(command, malicious_code) + torch.save( + torch.load(model_path), + f=unsafe_file_path, + pickle_module=PickleInject([payload]), + _use_new_zipfile_serialization=zipfile, + )