diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 228759b..4442f5e 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -91,11 +91,7 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]: with Model(file) as model: yield model - if ( - not _is_zipfile(file, model.get_stream()) - and Path(file).suffix - not in self._settings["supported_zip_extensions"] - ): + if not _is_zipfile(file, model.get_stream()): continue try: @@ -114,7 +110,7 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]: continue yield Model(file_name, file_io) - except zipfile.BadZipFile as e: + except (zipfile.BadZipFile, RuntimeError) as e: logger.debug( "Skipping zip file %s, due to error", str(model.get_source()), diff --git a/tests/data/password_protected.zip b/tests/data/password_protected.zip new file mode 100644 index 0000000..b1dd460 Binary files /dev/null and b/tests/data/password_protected.zip differ diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index 464d26c..7590e40 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -10,6 +10,7 @@ import dill import pytest import requests +import shutil import socket import subprocess import sys @@ -331,6 +332,10 @@ def file_path(tmp_path_factory: Any) -> Any: initialize_data_file(f"{tmp}/data/malicious14.pkl", malicious14_gen()) + shutil.copy( + f"{os.path.dirname(__file__)}/data/password_protected.zip", f"{tmp}/data/" + ) + return tmp @@ -1361,7 +1366,18 @@ def test_scan_directory_path(file_path: str) -> None: "benign0_v3.dill", "benign0_v4.dill", } - assert results["summary"]["skipped"]["skipped_files"] == [] + assert results["summary"]["skipped"]["skipped_files"] == [ + { + "category": "SCAN_NOT_SUPPORTED", + "description": "Model Scan did not scan file", + "source": "password_protected.zip", + }, + { + "category": "BAD_ZIP", + "description": "Skipping zip file due to error: File 'test.txt' is encrypted, password required for extraction", + "source": "password_protected.zip", + }, + ] assert results["errors"] == []