Skip to content

Commit

Permalink
* simplify code based on the feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
asofter committed Mar 20, 2024
1 parent c110330 commit 1c5fa7f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 31 deletions.
48 changes: 32 additions & 16 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,26 +154,30 @@ def scan(
model_path = Path(pathlib_path)

all_paths: List[Path] = []
scanned_paths: List[Path] = []
for model in self._iterate_models(model_path):
self._middleware_pipeline.run(model)
scanned = self._scan_source(model)
if scanned:
scanned_paths.append(model.get_source())

self._scan_source(model)
all_paths.append(model.get_source())

skipped_paths = list(set(all_paths) - set(scanned_paths))
if skipped_paths:
for skipped_path in skipped_paths:
self._skipped.append(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
str(skipped_path),
)
)
if self._skipped:
all_skipped_paths = [skipped.source for skipped in self._skipped]
for skipped in self._skipped:
main_file_path = skipped.source.split(":")[0]
if main_file_path == skipped.source:
continue

# If main container is skipped, we only add its content to skipped but not the file itself
if main_file_path in all_skipped_paths:
self._skipped = [
item for item in self._skipped if item.source != main_file_path
]

continue

# If main container is scanned, we consider all files to be scanned
self._skipped = [
item for item in self._skipped if item.source != skipped.source
]

return self._generate_results()

Expand Down Expand Up @@ -217,6 +221,18 @@ def _scan_source(
else:
self._scanned.append(str(model.get_source()))

if not scanned:
all_skipped_files = [skipped.source for skipped in self._skipped]
if str(model.get_source()) not in all_skipped_files:
self._skipped.append(
ModelScanSkipped(
"ModelScan",
SkipCategories.SCAN_NOT_SUPPORTED,
f"Model Scan did not scan file",
str(model.get_source()),
)
)

return scanned

def _generate_results(self) -> Dict[str, Any]:
Expand Down
16 changes: 1 addition & 15 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,6 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
for skipped_file in results["summary"]["skipped"]["skipped_files"]
]
) == {
"safe_zip_pytorch.pt",
"safe_zip_pytorch.pt:safe_zip_pytorch/byteorder",
"safe_zip_pytorch.pt:safe_zip_pytorch/version",
"safe_zip_pytorch.pt:safe_zip_pytorch/.data/serialization_id",
Expand Down Expand Up @@ -522,7 +521,6 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
for skipped_file in results["summary"]["skipped"]["skipped_files"]
]
) == {
"unsafe_zip_pytorch.pt",
"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/byteorder",
"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/version",
"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/.data/serialization_id",
Expand Down Expand Up @@ -1295,19 +1293,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
f"safe{file_extension}"
]

if file_extension == ".keras":
assert set(
[
skipped_file["source"]
for skipped_file in results["summary"]["skipped"]["skipped_files"]
]
) == {
f"safe{file_extension}:metadata.json",
f"safe{file_extension}:config.json",
f"safe{file_extension}:model.weights.h5",
}
else:
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["summary"]["skipped"]["skipped_files"] == []

assert results["errors"] == []

Expand Down

0 comments on commit 1c5fa7f

Please sign in to comment.