Skip to content

Commit

Permalink
Get scanners to load from settings and improve settings schema (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko authored Jan 12, 2024
1 parent b4797a6 commit 3b20681
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 89 deletions.
51 changes: 24 additions & 27 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import zipfile
import importlib

from modelscan.settings import DEFAULT_SCANNERS, DEFAULT_SETTINGS
from modelscan.settings import DEFAULT_SETTINGS

from pathlib import Path
from typing import List, Union, Optional, IO, Dict, Tuple, Any
Expand All @@ -20,7 +20,6 @@
class ModelScan:
def __init__(
self,
scanners_to_load: List[str] = DEFAULT_SCANNERS,
settings: Dict[str, Any] = DEFAULT_SETTINGS,
) -> None:
# Output
Expand All @@ -34,32 +33,30 @@ def __init__(
# Scanners
self._scanners_to_run: List[ScanBase] = []
self._settings: Dict[str, Any] = settings
self._load_scanners(scanners_to_load)

def _load_scanners(self, scanners_to_load: List[str]) -> None:
scanner_classes: Dict[str, ScanBase] = {}
for scanner_path in scanners_to_load:
try:
(modulename, classname) = scanner_path.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)
scanner_class = getattr(imported_module, classname)
scanner_classes[scanner_path] = scanner_class
except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
self._init_errors.append(
ModelScanError(
scanner_path, f"Error importing scanner {scanner_path}: {e}"
self._load_scanners()

def _load_scanners(self) -> None:
for scanner_path, scanner_settings in self._settings["scanners"].items():
if (
"enabled" in scanner_settings.keys()
and self._settings["scanners"][scanner_path]["enabled"]
):
try:
(modulename, classname) = scanner_path.rsplit(".", 1)
imported_module = importlib.import_module(
name=modulename, package=classname
)
)

scanners_to_run: List[ScanBase] = []
for scanner_class, scanner in scanner_classes.items():
is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"]
if is_enabled:
scanners_to_run.append(scanner)
self._scanners_to_run = scanners_to_run
scanner_class: ScanBase = getattr(imported_module, classname)
self._scanners_to_run.append(scanner_class)

except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
self._init_errors.append(
ModelScanError(
scanner_path, f"Error importing scanner {scanner_path}: {e}"
)
)

def scan(
self,
Expand Down Expand Up @@ -109,7 +106,7 @@ def _scan_source(
) -> bool:
scanned = False
for scan_class in self._scanners_to_run:
scanner = scan_class(self._settings["scanners"]) # type: ignore[operator]
scanner = scan_class(self._settings) # type: ignore[operator]
scan_results = scanner.scan(
source=source,
data=data,
Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[H5Scan.full_name()]["supported_extensions"]
in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"]
):
return None

Expand Down
2 changes: 1 addition & 1 deletion modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[KerasScan.full_name()]["supported_extensions"]
in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"]
):
return None

Expand Down
10 changes: 7 additions & 3 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[PyTorchScan.full_name()]["supported_extensions"]
in self._settings["scanners"][PyTorchScan.full_name()][
"supported_extensions"
]
):
return None

Expand Down Expand Up @@ -52,7 +54,7 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[NumpyScan.full_name()]["supported_extensions"]
in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"]
):
return None

Expand Down Expand Up @@ -81,7 +83,9 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[PickleScan.full_name()]["supported_extensions"]
in self._settings["scanners"][PickleScan.full_name()][
"supported_extensions"
]
):
return None

Expand Down
10 changes: 6 additions & 4 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def scan(
) -> Optional[ScanResults]:
if (
not Path(source).suffix
in self._settings[SavedModelScan.full_name()]["supported_extensions"]
in self._settings["scanners"][SavedModelScan.full_name()][
"supported_extensions"
]
):
return None

Expand Down Expand Up @@ -115,9 +117,9 @@ def _check_for_unsafe_tf_keras_operator(
source: Union[str, Path],
settings: Dict[str, Any],
) -> ScanResults:
unsafe_operators: Dict[str, Any] = settings[SavedModelScan.full_name()][
"unsafe_tf_keras_operators"
]
unsafe_operators: Dict[str, Any] = settings["scanners"][
SavedModelScan.full_name()
]["unsafe_tf_keras_operators"]

issues: List[Issue] = []
all_operators = tensorflow.raw_ops.__dict__.keys()
Expand Down
95 changes: 43 additions & 52 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,6 @@
from typing import Any


DEFAULT_SCANNERS = [
"modelscan.scanners.H5Scan",
"modelscan.scanners.KerasScan",
"modelscan.scanners.SavedModelScan",
"modelscan.scanners.NumpyScan",
"modelscan.scanners.PickleScan",
"modelscan.scanners.PyTorchScan",
]

DEFAULT_SETTINGS = {
"supported_zip_extensions": [".zip", ".npz"],
"scanners": {
Expand Down Expand Up @@ -51,50 +42,50 @@
"enabled": True,
"supported_extensions": [".bin", ".pt", ".pth", ".ckpt"],
},
"unsafe_globals": {
"CRITICAL": {
"__builtin__": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
"__import__",
], # Pickle versions 0, 1, 2 have those function under '__builtin__'
"builtins": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
"__import__",
], # Pickle versions 3, 4 have those function under 'builtins'
"runpy": "*",
"os": "*",
"nt": "*", # Alias for 'os' on Windows. Includes os.system()
"posix": "*", # Alias for 'os' on Linux. Includes os.system()
"socket": "*",
"subprocess": "*",
"sys": "*",
"operator": [
"attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned")
],
"pty": "*",
"pickle": "*",
},
"HIGH": {
"webbrowser": "*", # Includes webbrowser.open()
"httplib": "*", # Includes http.client.HTTPSConnection()
"requests.api": "*",
"aiohttp.client": "*",
},
"MEDIUM": {},
"LOW": {},
},
"unsafe_globals": {
"CRITICAL": {
"__builtin__": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
"__import__",
], # Pickle versions 0, 1, 2 have those function under '__builtin__'
"builtins": [
"eval",
"compile",
"getattr",
"apply",
"exec",
"open",
"breakpoint",
"__import__",
], # Pickle versions 3, 4 have those function under 'builtins'
"runpy": "*",
"os": "*",
"nt": "*", # Alias for 'os' on Windows. Includes os.system()
"posix": "*", # Alias for 'os' on Linux. Includes os.system()
"socket": "*",
"subprocess": "*",
"sys": "*",
"operator": [
"attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned")
],
"pty": "*",
"pickle": "*",
},
"HIGH": {
"webbrowser": "*", # Includes webbrowser.open()
"httplib": "*", # Includes http.client.HTTPSConnection()
"requests.api": "*",
"aiohttp.client": "*",
},
"MEDIUM": {},
"LOW": {},
},
}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from modelscan.settings import DEFAULT_SETTINGS

settings: Dict[str, Any] = DEFAULT_SETTINGS["scanners"] # type: ignore[assignment]
settings: Dict[str, Any] = DEFAULT_SETTINGS


class Malicious1:
Expand Down

0 comments on commit 3b20681

Please sign in to comment.