Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
get scanners from settings, improve settings schema
Browse files Browse the repository at this point in the history
swashko committed Jan 11, 2024
1 parent 3bd19eb commit 4291fbc
Showing 7 changed files with 85 additions and 85 deletions.
61 changes: 32 additions & 29 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
import zipfile
import importlib

from modelscan.settings import DEFAULT_SCANNERS, DEFAULT_SETTINGS
from modelscan.settings import DEFAULT_SETTINGS

from pathlib import Path
from typing import List, Union, Optional, IO, Dict, Tuple, Any
@@ -20,7 +20,6 @@
class ModelScan:
def __init__(
self,
scanners_to_load: List[str] = DEFAULT_SCANNERS,
settings: Dict[str, Any] = DEFAULT_SETTINGS,
) -> None:
# Output
@@ -33,38 +32,42 @@ def __init__(
# Scanners
self._scanners_to_run: List[ScanBase] = []
self._settings: Dict[str, Any] = settings
self._load_scanners(scanners_to_load)
self._load_scanners()

def _load_scanners(self, scanners_to_load: List[str]) -> None:
def _load_scanners(self) -> None:
scanner_classes: Dict[str, ScanBase] = {}
for scanner_path in scanners_to_load:
try:
(modulename, classname) = scanner_path.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)
scanner_class = getattr(imported_module, classname)
scanner_classes[scanner_path] = scanner_class
except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
self._errors.append(
ModelScanError(
scanner_path, f"Error importing scanner {scanner_path}: {e}"

for scanner_path, scanner_settings in self._settings["scanners"].items():
if (
"enabled" in scanner_settings.keys()
and self._settings["scanners"][scanner_path]["enabled"]
):
try:
(modulename, classname) = scanner_path.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)
scanner_class = getattr(imported_module, classname)
scanner_classes[scanner_path] = scanner_class
except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
self._errors.append(
ModelScanError(
scanner_path, f"Error importing scanner {scanner_path}: {e}"
)
)
)

scanners_to_run: List[ScanBase] = []
for scanner_class, scanner in scanner_classes.items():
is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"]
if is_enabled:
dep_error = scanner.handle_binary_dependencies()
if dep_error:
logger.info(
f"Skipping {scanner.full_name()} as it is missing dependencies"
)
self._errors.append(dep_error)
else:
scanners_to_run.append(scanner)
dep_error = scanner.handle_binary_dependencies()
if dep_error:
logger.info(
f"Skipping {scanner.full_name()} as it is missing dependencies"
)
self._errors.append(dep_error)
else:
scanners_to_run.append(scanner)

self._scanners_to_run = scanners_to_run

def scan(
@@ -114,7 +117,7 @@ def _scan_source(
) -> bool:
scanned = False
for scan_class in self._scanners_to_run:
scanner = scan_class(self._settings["scanners"]) # type: ignore[operator]
scanner = scan_class(self._settings) # type: ignore[operator]
scan_results = scanner.scan(
source=source,
data=data,
2 changes: 1 addition & 1 deletion modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[H5Scan.full_name()]["supported_extensions"]
in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"]
):
return None

2 changes: 1 addition & 1 deletion modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[KerasScan.full_name()]["supported_extensions"]
in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"]
):
return None

10 changes: 7 additions & 3 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
@@ -20,7 +20,9 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[PyTorchScan.full_name()]["supported_extensions"]
in self._settings["scanners"][PyTorchScan.full_name()][
"supported_extensions"
]
):
return None

@@ -52,7 +54,7 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[NumpyScan.full_name()]["supported_extensions"]
in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"]
):
return None

@@ -81,7 +83,9 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[PickleScan.full_name()]["supported_extensions"]
in self._settings["scanners"][PickleScan.full_name()][
"supported_extensions"
]
):
return None

10 changes: 6 additions & 4 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
@@ -31,7 +31,9 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[SavedModelScan.full_name()]["supported_extensions"]
in self._settings["scanners"][SavedModelScan.full_name()][
"supported_extensions"
]
):
return None

@@ -111,9 +113,9 @@ def _check_for_unsafe_tf_keras_operator(
source: Union[str, Path],
settings: Dict[str, Any],
) -> ScanResults:
unsafe_operators: Dict[str, Any] = settings[SavedModelScan.full_name()][
"unsafe_tf_keras_operators"
]
unsafe_operators: Dict[str, Any] = settings["scanners"][
SavedModelScan.full_name()
]["unsafe_tf_keras_operators"]

issues: List[Issue] = []
all_operators = tensorflow.raw_ops.__dict__.keys()
83 changes: 37 additions & 46 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
@@ -3,15 +3,6 @@
from typing import Any


DEFAULT_SCANNERS = [
"modelscan.scanners.H5Scan",
"modelscan.scanners.KerasScan",
"modelscan.scanners.SavedModelScan",
"modelscan.scanners.NumpyScan",
"modelscan.scanners.PickleScan",
"modelscan.scanners.PyTorchScan",
]

DEFAULT_SETTINGS = {
"supported_zip_extensions": [".zip", ".npz"],
"scanners": {
@@ -51,44 +42,44 @@
"enabled": True,
"supported_extensions": [".bin", ".pt", ".pth", ".ckpt"],
},
"unsafe_globals": {
"CRITICAL": {
"__builtin__": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
], # Pickle versions 0, 1, 2 have those function under '__builtin__'
"builtins": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
], # Pickle versions 3, 4 have those function under 'builtins'
"runpy": "*",
"os": "*",
"nt": "*", # Alias for 'os' on Windows. Includes os.system()
"posix": "*", # Alias for 'os' on Linux. Includes os.system()
"socket": "*",
"subprocess": "*",
"sys": "*",
"operator": "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned")
},
"HIGH": {
"webbrowser": "*", # Includes webbrowser.open()
"httplib": "*", # Includes http.client.HTTPSConnection()
"requests.api": "*",
"aiohttp.client": "*",
},
"MEDIUM": {},
"LOW": {},
},
"unsafe_globals": {
"CRITICAL": {
"__builtin__": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
], # Pickle versions 0, 1, 2 have those function under '__builtin__'
"builtins": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
], # Pickle versions 3, 4 have those function under 'builtins'
"runpy": "*",
"os": "*",
"nt": "*", # Alias for 'os' on Windows. Includes os.system()
"posix": "*", # Alias for 'os' on Linux. Includes os.system()
"socket": "*",
"subprocess": "*",
"sys": "*",
"operator": "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned")
},
"HIGH": {
"webbrowser": "*", # Includes webbrowser.open()
"httplib": "*", # Includes http.client.HTTPSConnection()
"requests.api": "*",
"aiohttp.client": "*",
},
"MEDIUM": {},
"LOW": {},
},
}

2 changes: 1 addition & 1 deletion tests/test_modelscan.py
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@
)
from modelscan.settings import DEFAULT_SETTINGS

settings: Dict[str, Any] = DEFAULT_SETTINGS["scanners"] # type: ignore[assignment]
settings: Dict[str, Any] = DEFAULT_SETTINGS


class Malicious1:

0 comments on commit 4291fbc

Please sign in to comment.