Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrinkiani committed Feb 7, 2024
1 parent 68fb048 commit 2b352c1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 22 deletions.
43 changes: 23 additions & 20 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import torch
import tensorflow as tf
from tensorflow import keras
from typing import Any, List, Set
from typing import Any, List, Set, Dict
from test_utils import (
generate_dill_unsafe_file,
generate_unsafe_pickle_file,
MaliciousModule,
PyTorchTestModel,
)
import zipfile

Expand All @@ -39,7 +40,7 @@

settings: Dict[str, Any] = DEFAULT_SETTINGS

from modelscan.models.saved_model.scan import SavedModelScan
from modelscan.scanners.saved_model.scan import SavedModelScan


class Malicious1:
Expand Down Expand Up @@ -998,17 +999,16 @@ def test_scan_directory_path(file_path: str) -> None:
"file_extension", [".h5", ".keras", ".pb"], ids=["h5", "keras", "pb"]
)
def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:

keras_file_path_parent_dir, safe_saved_model_dir, unsafe_saved_model_dir = (
keras_file_path[0],
keras_file_path[1],
keras_file_path[2],
)
ms = Modelscan()
ms = ModelScan()
if file_extension == ".pb":
ms.scan_path(Path(f"{safe_saved_model_dir}"))
ms.scan(Path(f"{safe_saved_model_dir}"))
else:
ms.scan_path(Path(f"{keras_file_path_parent_dir}/safe{file_extension}"))
ms.scan(Path(f"{keras_file_path_parent_dir}/safe{file_extension}"))

assert ms.issues.all_issues == []

Expand All @@ -1029,16 +1029,11 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
OperatorIssueDetails(
"Keras",
"Lambda",
f"{keras_file_path}/unsafe{file_extension}:config.json",
f"{keras_file_path_parent_dir}/unsafe{file_extension}:config.json",
),
),
]
ms._scan_source(

Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"),
extension=file_extension,

)
ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
elif file_extension == ".pb":
file_name = "keras_metadata.pb"
expected = [
Expand All @@ -1051,8 +1046,17 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
f"{unsafe_saved_model_dir}/{file_name}",
),
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.MEDIUM,
OperatorIssueDetails(
"Keras",
"Lambda",
f"{unsafe_saved_model_dir}/{file_name}",
),
),
]
ms.scan_path(Path(f"{unsafe_saved_model_dir}"))
ms.scan(Path(f"{unsafe_saved_model_dir}"))
else:
expected = [
Issue(
Expand All @@ -1070,12 +1074,12 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
OperatorIssueDetails(
"Keras",
"Lambda",
f"{keras_file_path}/unsafe{file_extension}",
f"{keras_file_path_parent_dir}/unsafe{file_extension}",
),
),
]

ms.scan_path(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
assert ms.issues.all_issues == expected


Expand All @@ -1084,8 +1088,8 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
tensorflow_file_path[0],
tensorflow_file_path[1],
)
ms = Modelscan()
ms.scan_path(Path(f"{safe_tensorflow_model_dir}"))
ms = ModelScan()
ms.scan(Path(f"{safe_tensorflow_model_dir}"))
assert ms.issues.all_issues == []

file_name = "saved_model.pb"
Expand All @@ -1109,8 +1113,7 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
),
),
]
ms.scan_path(Path(f"{unsafe_tensorflow_model_dir}"))

ms.scan(Path(f"{unsafe_tensorflow_model_dir}"))

assert ms.issues.all_issues == expected

Expand Down
2 changes: 0 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def call(self, x: float) -> Any:
return res




class PickleInject:
"""Pickle injection"""

Expand Down

0 comments on commit 2b352c1

Please sign in to comment.