diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 18d2418..1900bad 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -173,20 +173,20 @@ def pytorch_file_path(tmp_path_factory: Any) -> Any: initialize_data_file(f"{tmp}/bad_pytorch.pt", b"\211PNG\r\n\032\n") # Safe PyTorch files in old and new (zip) formats - model = PyTorchTestModel() + pt = PyTorchTestModel() torch.save( - model.state_dict(), + pt.model.state_dict(), f=f"{tmp}/safe_zip_pytorch.pt", _use_new_zipfile_serialization=True, ) torch.save( - model.state_dict(), + pt.model.state_dict(), f=f"{tmp}/safe_old_format_pytorch.pt", _use_new_zipfile_serialization=False, ) # Unsafe PyTorch files in new (zip) format - model.generate_unsafe_pytorch_file( + pt.generate_unsafe_pytorch_file( unsafe_file_path=f"{tmp}/unsafe_zip_pytorch.pt", model_path=f"{tmp}/safe_zip_pytorch.pt", zipfile=True, diff --git a/tests/test_utils.py b/tests/test_utils.py index 8a836d4..4bc4100 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -238,22 +238,9 @@ def generate_dill_unsafe_file( file_for_unsafe_model.close() -class PyTorchTestModel(nn.Module): # type: ignore[misc] +class PyTorchTestModel: def __init__(self) -> None: - super().__init__() - self.flatten = nn.Flatten() - self.linear_relu_stack = nn.Sequential( - nn.Linear(28 * 28, 512), - nn.ReLU(), - nn.Linear(512, 512), - nn.ReLU(), - nn.Linear(512, 10), - ) - - def forward(self, x: Any) -> Any: - x = self.flatten(x) - logits = self.linear_relu_stack(x) - return logits + self.model = nn.Module() def generate_unsafe_pytorch_file( self, unsafe_file_path: str, model_path: str, zipfile: bool = True