Skip to content

Commit

Permalink
Merge branch 'main' into 107-error-field
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrinkiani committed Feb 20, 2024
2 parents 1aa3dbb + 5a6e5dd commit b8e78d5
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
11 changes: 10 additions & 1 deletion modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,19 @@ def _scan_zip(
file_names = zip.namelist()
for file_name in file_names:
with zip.open(file_name, "r") as file_io:
self._scan_source(
scanned = self._scan_source(
source=f"{source}:{file_name}",
data=file_io,
)
if not scanned:
if _is_zipfile(file_name, data=file_io):
self._errors.append(
ModelScanError(
"ModelScan",
f"{source}:{file_name} is a zip file. ModelScan does not support nested zip files.",
)
)
self._skipped.append(f"{source}:{file_name}")
except zipfile.BadZipFile as e:
logger.debug(f"Skipping zip file {source}, due to error", e, exc_info=True)
self._skipped.append(str(source))
Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def scan(
):
return None

if _is_zipfile(source):
if _is_zipfile(source, data):
return None

if data:
Expand Down
31 changes: 26 additions & 5 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,11 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
assert results["summary"]["scanned"]["scanned_files"] == [
f"safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl"
]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["summary"]["skipped"]["skipped_files"] == [
"safe_zip_pytorch.pt:safe_zip_pytorch/byteorder",
"safe_zip_pytorch.pt:safe_zip_pytorch/version",
"safe_zip_pytorch.pt:safe_zip_pytorch/.data/serialization_id",
]
assert ms.issues.all_issues == []
assert results["errors"] == []

Expand All @@ -499,9 +503,13 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
]
results = ms.scan(unsafe_zip_path)
assert results["summary"]["scanned"]["scanned_files"] == [
f"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl"
f"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl",
]
assert results["summary"]["skipped"]["skipped_files"] == [
"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/byteorder",
"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/version",
"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/.data/serialization_id",
]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == expected
assert results["errors"] == []

Expand Down Expand Up @@ -1260,7 +1268,10 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
f"safe{file_extension}",
f"safe{file_extension}:model.weights.h5",
]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["summary"]["skipped"]["skipped_files"] == [
"safe.keras:metadata.json",
"safe.keras:config.json",
]
assert results["errors"] == [
{
"description": "modelscan.scanners.H5LambdaDetectScan got data bytes. It only support direct file scanning.",
Expand All @@ -1271,7 +1282,16 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
assert results["summary"]["scanned"]["scanned_files"] == [
f"safe{file_extension}"
]
assert results["summary"]["skipped"]["skipped_files"] == []

if file_extension == ".keras":
assert results["summary"]["skipped"]["skipped_files"] == [
f"safe{file_extension}:metadata.json",
f"safe{file_extension}:config.json",
f"safe{file_extension}:model.weights.h5",
]
else:
assert results["summary"]["skipped"]["skipped_files"] == []

assert results["errors"] == []

unsafe_filename = ""
Expand Down Expand Up @@ -1370,6 +1390,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
]

results = ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))

assert ms.issues.all_issues == expected
assert results["errors"] == []
assert results["summary"]["skipped"]["skipped_files"] == []
Expand Down

0 comments on commit b8e78d5

Please sign in to comment.