Skip to content

Commit

Permalink
Merge branch 'main' into nested-zip
Browse files: browse the repository at this point in the history
  • Loading branch information
swashko committed Feb 13, 2024
2 parents 4f418a5 + fd71c7b commit a18d17a
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 29 deletions.
2 changes: 1 addition & 1 deletion modelscan/scanners/h5/scan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -50,7 +50,7 @@ def scan(
def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]:
machine_learning_library_name = "Keras"
operators_in_model = self._get_keras_h5_operator_names(source)
if not operators_in_model:
if operators_in_model is None:
return None
return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
Expand Down
183 changes: 155 additions & 28 deletions tests/test_modelscan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -454,23 +454,35 @@ def test_scan_zip(zip_file_path: Any) -> None:
]

ms = ModelScan()
ms._scan_zip(f"{zip_file_path}/test.zip")
results = ms.scan(f"{zip_file_path}/test.zip")
assert results["summary"]["scanned"]["scanned_files"] == [f"test.zip:data.pkl"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == expected


def test_scan_pytorch(pytorch_file_path: Any) -> None:
ms = ModelScan()
ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt"))
results = ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt"))
assert results["summary"]["scanned"]["scanned_files"] == [f"bad_pytorch.pt"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == []
assert [error.scan_name for error in ms.errors] == ["pytorch"] # type: ignore[attr-defined]

ms.scan(Path(f"{pytorch_file_path}/safe_zip_pytorch.pt"))
results = ms.scan(Path(f"{pytorch_file_path}/safe_zip_pytorch.pt"))
assert results["summary"]["scanned"]["scanned_files"] == [
f"safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl"
]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == []
assert ms.errors == []
assert results["errors"] == []

ms.scan(Path(f"{pytorch_file_path}/safe_old_format_pytorch.pt"))
results = ms.scan(Path(f"{pytorch_file_path}/safe_old_format_pytorch.pt"))
assert results["summary"]["scanned"]["scanned_files"] == [
f"safe_old_format_pytorch.pt"
]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == []
assert ms.errors == []
assert results["errors"] == []

unsafe_zip_path = f"{pytorch_file_path}/unsafe_zip_pytorch.pt"
expected = [
Expand All @@ -485,39 +497,57 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
),
),
]
ms.scan(unsafe_zip_path)
assert ms.errors == []
results = ms.scan(unsafe_zip_path)
assert results["summary"]["scanned"]["scanned_files"] == [
f"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl"
]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == expected
assert results["errors"] == []


def test_scan_numpy(numpy_file_path: Any) -> None:
with open(f"{numpy_file_path}/safe_numpy.npy", "rb") as f:
assert scan_numpy(io.BytesIO(f.read()), "safe_numpy.npy", settings).issues == []
ms = ModelScan()
results = ms.scan(f"{numpy_file_path}/safe_numpy.npy")
assert ms.issues.all_issues == []
assert results["summary"]["scanned"]["scanned_files"] == [f"safe_numpy.npy"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []

expected = {
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"builtins", "exec", IssueSeverity.CRITICAL, "unsafe_numpy.npy"
"builtins",
"exec",
IssueSeverity.CRITICAL,
f"{numpy_file_path}/unsafe_numpy.npy",
),
),
}

with open(f"{numpy_file_path}/unsafe_numpy.npy", "rb") as f:
compare_results(
scan_numpy(io.BytesIO(f.read()), "unsafe_numpy.npy", settings).issues,
expected,
)
results = ms.scan(f"{numpy_file_path}/unsafe_numpy.npy")
compare_results(ms.issues.all_issues, expected)
assert results["summary"]["scanned"]["scanned_files"] == [f"unsafe_numpy.npy"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []


def test_scan_file_path(file_path: Any) -> None:
benign_pickle = ModelScan()
benign_pickle.scan(Path(f"{file_path}/data/benign0_v3.pkl"))
benign_dill = ModelScan()
benign_dill.scan(Path(f"{file_path}/data/benign0_v3.dill"))
results = benign_pickle.scan(Path(f"{file_path}/data/benign0_v3.pkl"))
assert benign_pickle.issues.all_issues == []
assert results["summary"]["scanned"]["scanned_files"] == [f"benign0_v3.pkl"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []

benign_dill = ModelScan()
results = benign_dill.scan(Path(f"{file_path}/data/benign0_v3.dill"))
assert benign_dill.issues.all_issues == []
assert results["summary"]["scanned"]["scanned_files"] == [f"benign0_v3.dill"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []

malicious0 = ModelScan()
expected_malicious0 = {
Expand Down Expand Up @@ -562,8 +592,11 @@ def test_scan_file_path(file_path: Any) -> None:
),
),
}
malicious0.scan(Path(f"{file_path}/data/malicious0.pkl"))
results = malicious0.scan(Path(f"{file_path}/data/malicious0.pkl"))
compare_results(malicious0.issues.all_issues, expected_malicious0)
assert results["summary"]["scanned"]["scanned_files"] == [f"malicious0.pkl"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []


def test_scan_pickle_operators(file_path: Any) -> None:
Expand Down Expand Up @@ -1154,8 +1187,40 @@ def test_scan_directory_path(file_path: str) -> None:
}
ms = ModelScan()
p = Path(f"{file_path}/data/")
ms.scan(p)
results = ms.scan(p)
compare_results(ms.issues.all_issues, expected)
assert set(results["summary"]["scanned"]["scanned_files"]) == {
f"malicious1.zip:data.pkl",
f"malicious0.pkl",
f"malicious3.pkl",
f"malicious6.pkl",
f"malicious7.pkl",
f"malicious8.pkl",
f"malicious9.pkl",
f"malicious10.pkl",
f"malicious11.pkl",
f"malicious12.pkl",
f"malicious13.pkl",
f"malicious1_v0.dill",
f"malicious1_v3.dill",
f"malicious1_v4.dill",
f"malicious4.pickle",
f"malicious5.pickle",
f"malicious1_v0.pkl",
f"malicious1_v3.pkl",
f"malicious1_v4.pkl",
f"malicious2_v0.pkl",
f"malicious2_v3.pkl",
f"malicious2_v4.pkl",
f"benign0_v0.pkl",
f"benign0_v3.pkl",
f"benign0_v4.pkl",
f"benign0_v0.dill",
f"benign0_v3.dill",
f"benign0_v4.dill",
}
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []


@pytest.mark.parametrize(
Expand All @@ -1168,14 +1233,37 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
keras_file_path[2],
)
ms = ModelScan()
results = {}
safe_filename = ""
if file_extension == ".pb":
ms.scan(Path(f"{safe_saved_model_dir}"))
safe_filename = f"{safe_saved_model_dir}"
else:
ms.scan(Path(f"{keras_file_path_parent_dir}/safe{file_extension}"))
safe_filename = f"{keras_file_path_parent_dir}/safe{file_extension}"

results = ms.scan(Path(safe_filename))

assert ms.issues.all_issues == []
if file_extension == ".pb":
assert set(results["summary"]["scanned"]["scanned_files"]) == {
f"fingerprint.pb",
f"keras_metadata.pb",
f"saved_model.pb",
}
assert set(results["summary"]["skipped"]["skipped_files"]) == {
f"variables/variables.data-00000-of-00001",
f"variables/variables.index",
}
else:
assert results["summary"]["scanned"]["scanned_files"] == [
f"safe{file_extension}"
]
assert results["summary"]["skipped"]["skipped_files"] == []

assert results["errors"] == []

unsafe_filename = ""
if file_extension == ".keras":
unsafe_filename = f"{keras_file_path_parent_dir}/unsafe{file_extension}"
expected = [
Issue(
IssueCode.UNSAFE_OPERATOR,
Expand All @@ -1198,9 +1286,10 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
),
),
]
ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
results = ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
elif file_extension == ".pb":
file_name = "keras_metadata.pb"
unsafe_filename = f"{unsafe_saved_model_dir}"
expected = [
Issue(
IssueCode.UNSAFE_OPERATOR,
Expand All @@ -1223,8 +1312,9 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
),
),
]
ms.scan(Path(f"{unsafe_saved_model_dir}"))
results = ms.scan(Path(f"{unsafe_saved_model_dir}"))
else:
unsafe_filename = f"{keras_file_path_parent_dir}/unsafe{file_extension}"
expected = [
Issue(
IssueCode.UNSAFE_OPERATOR,
Expand All @@ -1248,8 +1338,25 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
),
]

ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
results = ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
assert ms.issues.all_issues == expected
assert results["errors"] == []

if file_extension == ".pb":
assert set(results["summary"]["scanned"]["scanned_files"]) == {
f"fingerprint.pb",
f"keras_metadata.pb",
f"saved_model.pb",
}
assert set(results["summary"]["skipped"]["skipped_files"]) == {
f"variables/variables.data-00000-of-00001",
f"variables/variables.index",
}
else:
assert results["summary"]["scanned"]["scanned_files"] == [
f"unsafe{file_extension}"
]
assert results["summary"]["skipped"]["skipped_files"] == []


def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
Expand All @@ -1258,8 +1365,18 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
tensorflow_file_path[1],
)
ms = ModelScan()
ms.scan(Path(f"{safe_tensorflow_model_dir}"))
results = ms.scan(Path(f"{safe_tensorflow_model_dir}"))
assert ms.issues.all_issues == []
assert set(results["summary"]["scanned"]["scanned_files"]) == {
f"fingerprint.pb",
f"keras_metadata.pb",
f"saved_model.pb",
}
assert set(results["summary"]["skipped"]["skipped_files"]) == {
f"variables/variables.data-00000-of-00001",
f"variables/variables.index",
}
assert results["errors"] == []

file_name = "saved_model.pb"
expected = [
Expand All @@ -1284,9 +1401,19 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
),
),
]
ms.scan(Path(f"{unsafe_tensorflow_model_dir}"))
results = ms.scan(Path(f"{unsafe_tensorflow_model_dir}"))

assert ms.issues.all_issues == expected
assert set(results["summary"]["scanned"]["scanned_files"]) == {
f"fingerprint.pb",
f"keras_metadata.pb",
f"saved_model.pb",
}
assert set(results["summary"]["skipped"]["skipped_files"]) == {
f"variables/variables.data-00000-of-00001",
f"variables/variables.index",
}
assert results["errors"] == []


def test_main(file_path: Any) -> None:
Expand Down

0 comments on commit a18d17a

Please sign in to comment.