Skip to content

Commit

Permalink
simplify test model
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko committed Jan 19, 2024
1 parent 6d1349e commit e1a12d0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 19 deletions.
8 changes: 4 additions & 4 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,20 +173,20 @@ def pytorch_file_path(tmp_path_factory: Any) -> Any:
initialize_data_file(f"{tmp}/bad_pytorch.pt", b"\211PNG\r\n\032\n")

# Safe PyTorch files in old and new (zip) formats
model = PyTorchTestModel()
pt = PyTorchTestModel()
torch.save(
model.state_dict(),
pt.model.state_dict(),
f=f"{tmp}/safe_zip_pytorch.pt",
_use_new_zipfile_serialization=True,
)
torch.save(
model.state_dict(),
pt.model.state_dict(),
f=f"{tmp}/safe_old_format_pytorch.pt",
_use_new_zipfile_serialization=False,
)

# Unsafe PyTorch files in new (zip) format
model.generate_unsafe_pytorch_file(
pt.generate_unsafe_pytorch_file(
unsafe_file_path=f"{tmp}/unsafe_zip_pytorch.pt",
model_path=f"{tmp}/safe_zip_pytorch.pt",
zipfile=True,
Expand Down
17 changes: 2 additions & 15 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,22 +238,9 @@ def generate_dill_unsafe_file(
file_for_unsafe_model.close()


class PyTorchTestModel(nn.Module): # type: ignore[misc]
class PyTorchTestModel:
def __init__(self) -> None:
super().__init__()
self.flatten = nn.Flatten()
self.linear_relu_stack = nn.Sequential(
nn.Linear(28 * 28, 512),
nn.ReLU(),
nn.Linear(512, 512),
nn.ReLU(),
nn.Linear(512, 10),
)

def forward(self, x: Any) -> Any:
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
self.model = nn.Module()

def generate_unsafe_pytorch_file(
self, unsafe_file_path: str, model_path: str, zipfile: bool = True
Expand Down

0 comments on commit e1a12d0

Please sign in to comment.