Skip to content

Commit

Permalink
pytorch zipfile check and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko committed Jan 19, 2024
1 parent 3b20681 commit 8e570f6
Showing 3 changed files with 92 additions and 5 deletions.
4 changes: 4 additions & 0 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
from typing import IO, Union, Optional

from modelscan.scanners.scan import ScanBase, ScanResults
from modelscan.tools.utils import _is_zipfile
from modelscan.tools.picklescanner import (
scan_numpy,
scan_pickle_bytes,
@@ -26,6 +27,9 @@ def scan(
):
return None

if _is_zipfile(source):
return None

if data:
results = scan_pytorch(data=data, source=source, settings=self._settings)

58 changes: 53 additions & 5 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
@@ -12,10 +12,15 @@
import socket
import subprocess
import sys
import torch
import tensorflow as tf
from tensorflow import keras
from typing import Any, List, Set, Dict
from test_utils import generate_dill_unsafe_file, generate_unsafe_pickle_file
from test_utils import (
generate_dill_unsafe_file,
generate_unsafe_pickle_file,
PyTorchTestModel,
)
import zipfile

from modelscan.modelscan import ModelScan
@@ -166,6 +171,27 @@ def pytorch_file_path(tmp_path_factory: Any) -> Any:
tmp = tmp_path_factory.mktemp("pytorch")
# Fake PyTorch file (PNG file format) simulating https://huggingface.co/RectalWorm/loras_new/blob/main/Owl_Mage_no_background.pt
initialize_data_file(f"{tmp}/bad_pytorch.pt", b"\211PNG\r\n\032\n")

# Safe PyTorch files in old and new (zip) formats
model = PyTorchTestModel()
torch.save(
model.state_dict(),
f=f"{tmp}/safe_zip_pytorch.pt",
_use_new_zipfile_serialization=True,
)
torch.save(
model.state_dict(),
f=f"{tmp}/safe_old_format_pytorch.pt",
_use_new_zipfile_serialization=False,
)

# Unsafe PyTorch files in new (zip) format
model.generate_unsafe_pytorch_file(
unsafe_file_path=f"{tmp}/unsafe_zip_pytorch.pt",
model_path=f"{tmp}/safe_zip_pytorch.pt",
zipfile=True,
)

return tmp


@@ -382,10 +408,32 @@ def test_scan_zip(zip_file_path: Any) -> None:


def test_scan_pytorch(pytorch_file_path: Any) -> None:
bad_pytorch = ModelScan()
bad_pytorch.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt"))
assert bad_pytorch.issues.all_issues == []
assert [error.scan_name for error in bad_pytorch.errors] == ["pytorch"] # type: ignore[attr-defined]
ms = ModelScan()
ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt"))
assert ms.issues.all_issues == []
assert [error.scan_name for error in ms.errors] == ["pytorch"] # type: ignore[attr-defined]

ms.scan(Path(f"{pytorch_file_path}/safe_zip_pytorch.pt"))
assert ms.issues.all_issues == []
assert ms.errors == []

ms.scan(Path(f"{pytorch_file_path}/safe_old_format_pytorch.pt"))
assert ms.issues.all_issues == []
assert ms.errors == []

unsafe_zip_path = f"{pytorch_file_path}/unsafe_zip_pytorch.pt"
expected = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"posix", "system", f"{unsafe_zip_path}:unsafe_zip_pytorch/data.pkl"
),
),
]
ms.scan(unsafe_zip_path)
assert ms.errors == []
assert ms.issues.all_issues == expected


def test_scan_numpy(numpy_file_path: Any) -> None:
35 changes: 35 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,8 @@
import struct
from typing import Any, Tuple
import os
import torch
import torch.nn as nn


class PickleInject:
@@ -234,3 +236,36 @@ def generate_dill_unsafe_file(
mypickler = DillInject._Pickler(file_for_unsafe_model, pickle_protocol, [payload])
mypickler.dump(safe_model)
file_for_unsafe_model.close()


class PyTorchTestModel(nn.Module): # type: ignore[misc]
def __init__(self) -> None:
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)

def forward(self, x: Any) -> Any:
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits

def generate_unsafe_pytorch_file(
self, unsafe_file_path: str, model_path: str, zipfile: bool = True
) -> None:
command = "system"
malicious_code = """cat ~/.aws/secrets
"""

payload = get_pickle_payload(command, malicious_code)
torch.save(
torch.load(model_path),
f=unsafe_file_path,
pickle_module=PickleInject([payload]),
_use_new_zipfile_serialization=zipfile,
)

0 comments on commit 8e570f6

Please sign in to comment.