Skip to content

Commit

Permalink
Skipping files with invalid magic number
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrinkiani committed Feb 23, 2024
1 parent 3a6b26b commit 9195dae
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 52 deletions.
16 changes: 9 additions & 7 deletions modelscan/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@ class ErrorCategories(Enum):
PATH = 3
NESTED_ZIP = 4
PICKLE_GENOPS = 5
MAGIC_NUMBER = 6
JSON_DECODE = 7
JSON_DECODE = 6


class Error:
scan_name: str
category: ErrorCategories
message: Optional[str]
message: str
source: Optional[str]

def __init__(self) -> None:
Expand All @@ -30,13 +29,16 @@ def __init__(
self,
scan_name: str,
category: ErrorCategories,
message: Optional[str] = None,
message: str,
source: Optional[str] = None,
) -> None:
self.scan_name = scan_name
self.category = category
self.message = message or "None"
self.source = str(source)
self.message = message
self.source = source

def __str__(self) -> str:
return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"
if self.source:
return f"The following error was raised during a {self.scan_name} scan of file {self.source}: \n{self.message}"
else:
return f"The following error was raised during a {self.scan_name} scan: \n{self.message}"
66 changes: 34 additions & 32 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,48 +218,50 @@ def _generate_results(self) -> Dict[str, Any]:
report["summary"]["timestamp"] = datetime.now().isoformat()

report["summary"]["scanned"] = {"total_scanned": len(self._scanned)}
report["summary"]["scanned"]["scanned_files"] = [
str(Path(file_name).relative_to(Path(absolute_path)))
for file_name in self._scanned
]

report["issues"] = [
issue.details.output_json() for issue in self._issues.all_issues
]

for issue in report["issues"]:
issue["source"] = str(
Path(issue["source"]).relative_to(Path(absolute_path))
)

all_errors = []
if self._scanned:
report["summary"]["scanned"]["scanned_files"] = [
str(Path(file_name).relative_to(Path(absolute_path)))
for file_name in self._scanned
]

for error in self._errors:
error_information = {}
error_information["category"] = str(error.category.name)
if error.message is not None:
error_information["description"] = error.message
if hasattr(error, "source"):
error_information["source"] = str(
Path(str(error.source)).relative_to(Path(absolute_path))
if self._issues.all_issues:
report["issues"] = [
issue.details.output_json() for issue in self._issues.all_issues
]

for issue in report["issues"]:
issue["source"] = str(
Path(issue["source"]).relative_to(Path(absolute_path))
)
all_errors = []
if self._errors:
for error in self._errors:
error_information = {}
error_information["category"] = str(error.category.name)
if error.message:
error_information["description"] = error.message
if error.source is not None:
error_information["source"] = str(
Path(str(error.source)).relative_to(Path(absolute_path))
)

all_errors.append(error_information)
all_errors.append(error_information)

report["errors"] = all_errors

report["summary"]["skipped"] = {"total_skipped": len(self._skipped)}

all_skipped_files = []

for skipped_file in self._skipped:
skipped_file_information = {}
skipped_file_information["category"] = str(skipped_file.category.name)
skipped_file_information["description"] = str(skipped_file.message)
skipped_file_information["source"] = str(
Path(skipped_file.source).relative_to(Path(absolute_path))
)
all_skipped_files.append(skipped_file_information)
if self._skipped:
for skipped_file in self._skipped:
skipped_file_information = {}
skipped_file_information["category"] = str(skipped_file.category.name)
skipped_file_information["description"] = str(skipped_file.message)
skipped_file_information["source"] = str(
Path(skipped_file.source).relative_to(Path(absolute_path))
)
all_skipped_files.append(skipped_file_information)

report["summary"]["skipped"]["skipped_files"] = all_skipped_files

Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def scan(
ModelScanError(
self.name(),
ErrorCategories.DEPENDENCY,
f"To use {self.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.",
f"To use {self.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan[[ h5py]]\"' if you are using pip.",
)
],
[],
Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def scan(
ModelScanError(
self.name(),
ErrorCategories.DEPENDENCY,
f"To use {self.full_name()}, please install modelscan with dependencies.",
f"To use {self.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan[[ tensorflow]]\"' if you are using pip.",
)
],
[],
Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def scan(
ModelScanError(
self.name(),
ErrorCategories.DEPENDENCY,
f"To use {self.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.",
f"To use {self.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan[[ tensorflow ]]\"' if you are using pip.",
)
],
[],
Expand Down
3 changes: 2 additions & 1 deletion modelscan/skip.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SkipCategories(Enum):
MODEL_CONFIG = 3
H5_DATA = 4
NOT_IMPLEMENTED = 5
MAGIC_NUMBER = 6


class Skip:
Expand All @@ -40,7 +41,7 @@ def __init__(
) -> None:
self.scan_name = scan_name
self.category = category
self.message = message or "None"
self.message = message
self.source = str(source)

def __str__(self) -> str:
Expand Down
10 changes: 5 additions & 5 deletions modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def scan_numpy(
)
],
)
# raise NotImplementedError("Scanning of .npz files is not implemented yet")

elif magic == np.lib.format.MAGIC_PREFIX:
# .npy file
version = np.lib.format.read_magic(data) # type: ignore[no-untyped-call]
Expand Down Expand Up @@ -242,15 +242,15 @@ def scan_pytorch(
magic = get_magic_number(data)
if magic != MAGIC_NUMBER:
return ScanResults(
[],
[],
[
ModelScanError(
ModelScanSkipped(
scan_name,
ErrorCategories.MAGIC_NUMBER,
SkipCategories.MAGIC_NUMBER,
f"Invalid magic number",
str(source),
),
)
],
[],
)
return scan_pickle_bytes(data, source, settings, scan_name, multiple_pickles=False)
7 changes: 3 additions & 4 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
scan_numpy,
)

from modelscan.error import ErrorCategories
from modelscan.skip import SkipCategories
from modelscan.settings import DEFAULT_SETTINGS

settings: Dict[str, Any] = DEFAULT_SETTINGS
Expand Down Expand Up @@ -464,15 +464,14 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
ms = ModelScan()
results = ms.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt"))

assert results["errors"] == [
assert results["summary"]["skipped"]["skipped_files"] == [
{
"category": ErrorCategories.MAGIC_NUMBER.name,
"category": SkipCategories.MAGIC_NUMBER.name,
"description": f"Invalid magic number",
"source": f"bad_pytorch.pt",
}
]
assert ms.issues.all_issues == []
assert [error.scan_name for error in ms.errors] == ["pytorch"]

results = ms.scan(Path(f"{pytorch_file_path}/safe_zip_pytorch.pt"))
assert results["summary"]["scanned"]["scanned_files"] == [
Expand Down

0 comments on commit 9195dae

Please sign in to comment.