diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 1592f14..2745990 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -154,26 +154,30 @@ def scan( model_path = Path(pathlib_path) all_paths: List[Path] = [] - scanned_paths: List[Path] = [] for model in self._iterate_models(model_path): self._middleware_pipeline.run(model) - scanned = self._scan_source(model) - if scanned: - scanned_paths.append(model.get_source()) - + self._scan_source(model) all_paths.append(model.get_source()) - skipped_paths = list(set(all_paths) - set(scanned_paths)) - if skipped_paths: - for skipped_path in skipped_paths: - self._skipped.append( - ModelScanSkipped( - "ModelScan", - SkipCategories.SCAN_NOT_SUPPORTED, - f"Model Scan did not scan file", - str(skipped_path), - ) - ) + if self._skipped: + all_skipped_paths = [skipped.source for skipped in self._skipped] + for skipped in self._skipped: + main_file_path = skipped.source.split(":")[0] + if main_file_path == skipped.source: + continue + + # If main container is skipped, we only add its content to skipped but not the file itself + if main_file_path in all_skipped_paths: + self._skipped = [ + item for item in self._skipped if item.source != main_file_path + ] + + continue + + # If main container is scanned, we consider all files to be scanned + self._skipped = [ + item for item in self._skipped if item.source != skipped.source + ] return self._generate_results() @@ -217,6 +221,18 @@ def _scan_source( else: self._scanned.append(str(model.get_source())) + if not scanned: + all_skipped_files = [skipped.source for skipped in self._skipped] + if str(model.get_source()) not in all_skipped_files: + self._skipped.append( + ModelScanSkipped( + "ModelScan", + SkipCategories.SCAN_NOT_SUPPORTED, + f"Model Scan did not scan file", + str(model.get_source()), + ) + ) + return scanned def _generate_results(self) -> Dict[str, Any]: diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 0136841..cc9963d 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -483,7 +483,6 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None: for skipped_file in results["summary"]["skipped"]["skipped_files"] ] ) == { - "safe_zip_pytorch.pt", "safe_zip_pytorch.pt:safe_zip_pytorch/byteorder", "safe_zip_pytorch.pt:safe_zip_pytorch/version", "safe_zip_pytorch.pt:safe_zip_pytorch/.data/serialization_id", @@ -522,7 +521,6 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None: for skipped_file in results["summary"]["skipped"]["skipped_files"] ] ) == { - "unsafe_zip_pytorch.pt", "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/byteorder", "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/version", "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/.data/serialization_id", @@ -1295,19 +1293,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None: f"safe{file_extension}" ] - if file_extension == ".keras": - assert set( - [ - skipped_file["source"] - for skipped_file in results["summary"]["skipped"]["skipped_files"] - ] - ) == { - f"safe{file_extension}:metadata.json", - f"safe{file_extension}:config.json", - f"safe{file_extension}:model.weights.h5", - } - else: - assert results["summary"]["skipped"]["skipped_files"] == [] + assert results["summary"]["skipped"]["skipped_files"] == [] assert results["errors"] == []