Skip to content

Commit

Permalink
Merge branch 'main' into scanners-from-settings
Browse files Browse the repository at this point in the history
  • Loading branch information
swashko committed Jan 12, 2024
2 parents 4291fbc + b4797a6 commit c093bc5
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 45 deletions.
3 changes: 3 additions & 0 deletions modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,6 @@ def output_json(self) -> Dict[str, str]:
"source": f"{str(self.source)}",
"scanner": f"{self.scanner}",
}

def __repr__(self) -> str:
return f"<OperatorIssueDetails(module={self.module}, operator={self.operator}, source={str(self.source)})>"
14 changes: 5 additions & 9 deletions modelscan/modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
# Output
self._issues = Issues()
self._errors: List[Error] = []
self._init_errors: List[Error] = []
self._skipped: List[str] = []
self._scanned: List[str] = []
self._input_path: str = ""
Expand All @@ -51,23 +52,17 @@ def _load_scanners(self) -> None:
scanner_classes[scanner_path] = scanner_class
except Exception as e:
logger.error(f"Error importing scanner {scanner_path}")
self._errors.append(
self._init_errors.append(
ModelScanError(
scanner_path, f"Error importing scanner {scanner_path}: {e}"
)
)

scanners_to_run: List[ScanBase] = []
for scanner_class, scanner in scanner_classes.items():
dep_error = scanner.handle_binary_dependencies()
if dep_error:
logger.info(
f"Skipping {scanner.full_name()} as it is missing dependencies"
)
self._errors.append(dep_error)
else:
is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"]
if is_enabled:
scanners_to_run.append(scanner)

self._scanners_to_run = scanners_to_run

def scan(
Expand All @@ -76,6 +71,7 @@ def scan(
) -> Dict[str, Any]:
self._issues = Issues()
self._errors = []
self._errors.extend(self._init_errors)
self._skipped = []
self._scanned = []
self._input_path = str(path)
Expand Down
4 changes: 4 additions & 0 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def scan(
):
return None

dep_error = self.handle_binary_dependencies()
if dep_error:
return ScanResults([], [dep_error])

if data:
logger.warning(
"H5 scanner got data bytes. It only support direct file scanning."
Expand Down
4 changes: 4 additions & 0 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def scan(
):
return None

dep_error = self.handle_binary_dependencies()
if dep_error:
return ScanResults([], [dep_error])

try:
with zipfile.ZipFile(data or source, "r") as zip:
file_names = zip.namelist()
Expand Down
4 changes: 4 additions & 0 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ def scan(
):
return None

dep_error = self.handle_binary_dependencies()
if dep_error:
return ScanResults([], [dep_error])

if data:
results = self._scan(source, data)

Expand Down
8 changes: 7 additions & 1 deletion modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"exec",
"open",
"breakpoint",
"__import__",
], # Pickle versions 0, 1, 2 have those function under '__builtin__'
"builtins": [
"eval",
Expand All @@ -62,6 +63,7 @@
"exec",
"open",
"breakpoint",
"__import__",
], # Pickle versions 3, 4 have those function under 'builtins'
"runpy": "*",
"os": "*",
Expand All @@ -70,7 +72,11 @@
"socket": "*",
"subprocess": "*",
"sys": "*",
"operator": "attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned")
"operator": [
"attrgetter", # Ex of code execution: operator.attrgetter("system")(__import__("os"))("echo pwned")
],
"pty": "*",
"pickle": "*",
},
"HIGH": {
"webbrowser": "*", # Includes webbrowser.open()
Expand Down
66 changes: 31 additions & 35 deletions modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,45 +134,41 @@ def scan_pickle_bytes(
)

logger.debug("Global imports in %s: %s", source, raw_globals)
severities = {
"CRITICAL": IssueSeverity.CRITICAL,
"HIGH": IssueSeverity.HIGH,
"MEDIUM": IssueSeverity.MEDIUM,
"LOW": IssueSeverity.LOW,
}

for rg in raw_globals:
global_module, global_name, severity = rg[0], rg[1], None
unsafe_critical_filter = settings["unsafe_globals"]["CRITICAL"].get(
global_module
)
unsafe_high_filter = settings["unsafe_globals"]["HIGH"].get(global_module)
unsafe_medium_filter = settings["unsafe_globals"]["MEDIUM"].get(global_module)
unsafe_low_filter = settings["unsafe_globals"]["LOW"].get(global_module)
if unsafe_critical_filter is not None and (
unsafe_critical_filter == "*" or global_name in unsafe_critical_filter
):
severity = IssueSeverity.CRITICAL

elif unsafe_high_filter is not None and (
unsafe_high_filter == "*" or global_name in unsafe_high_filter
):
severity = IssueSeverity.HIGH
elif unsafe_medium_filter is not None and (
unsafe_medium_filter == "*" or global_name in unsafe_medium_filter
):
severity = IssueSeverity.MEDIUM
elif unsafe_low_filter is not None and (
unsafe_low_filter == "*" or global_name in unsafe_low_filter
):
severity = IssueSeverity.LOW
elif "unknown" in global_module or "unknown" in global_name:
severity = IssueSeverity.MEDIUM
else:
continue
issues.append(
Issue(
code=IssueCode.UNSAFE_OPERATOR,
severity=severity,
details=OperatorIssueDetails(
module=global_module, operator=global_name, source=source
),
for severity_name in severities:
if global_module not in settings["unsafe_globals"][severity_name]:
continue
filter = settings["unsafe_globals"][severity_name][global_module]
if filter == "*":
severity = severities[severity_name]
break
for filter_value in filter:
if filter_value in global_name:
severity = severities[severity_name]
break
else:
continue
break
if "unknown" in global_module or "unknown" in global_name:
severity = IssueSeverity.CRITICAL # we must assume it is RCE
if severity is not None:
issues.append(
Issue(
code=IssueCode.UNSAFE_OPERATOR,
severity=severity,
details=OperatorIssueDetails(
module=global_module, operator=global_name, source=source
),
)
)
)
return ScanResults(issues, [])


Expand Down
88 changes: 88 additions & 0 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,52 @@ def __reduce__(self) -> Any:
return sys.exit, (0,)


def malicious12_gen() -> bytes:
p = pickle.PROTO + b"\x05"

# stack = [pickle.loads]
p += pickle.GLOBAL + b"pickle\nloads\n"

# stack = [pickle.loads, p2]
p2 = (
pickle.PROTO
+ b"\x05"
+ pickle.GLOBAL
+ b"os\nsystem\n"
+ pickle.UNICODE
+ b"echo pwned!!!\n"
+ pickle.TUPLE1
+ pickle.REDUCE
+ pickle.STOP
)
p += pickle.BINBYTES + len(p2).to_bytes(4, "little") + p2

# stack = [pickle.loads, (p2,)]
p += pickle.TUPLE1

# stack = [pickle.loads(p2)]
p += pickle.REDUCE

# return None
p += pickle.POP
p += pickle.NONE

p += pickle.STOP
return p


def malicious13_gen() -> bytes:
p = pickle.PROTO + b"\x05"

p += pickle.GLOBAL + b"builtins\neval.__call__\n"
p += pickle.UNICODE + b'__import__("os").system("echo pwned!!!")\n'
p += pickle.TUPLE1
p += pickle.REDUCE

p += pickle.STOP
return p


def initialize_pickle_file(path: str, obj: Any, version: int) -> None:
if not os.path.exists(path):
with open(path, "wb") as file:
Expand Down Expand Up @@ -209,6 +255,10 @@ def file_path(tmp_path_factory: Any) -> Any:
)
initialize_data_file(f"{tmp}/data/malicious10.pkl", malicious10_pickle_bytes)

initialize_data_file(f"{tmp}/data/malicious12.pkl", malicious12_gen())

initialize_data_file(f"{tmp}/data/malicious13.pkl", malicious13_gen())

return tmp


Expand Down Expand Up @@ -643,6 +693,30 @@ def test_scan_pickle_operators(file_path: Any) -> None:
malicious11 = ModelScan()
malicious11.scan(Path(f"{file_path}/data/malicious11.pkl"))
assert malicious11.issues.all_issues == expected_malicious11
expected_malicious12 = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"pickle", "loads", f"{file_path}/data/malicious12.pkl"
),
)
]
malicious12 = ModelScan()
malicious12.scan(Path(f"{file_path}/data/malicious12.pkl"))
assert malicious12.issues.all_issues == expected_malicious12
expected_malicious13 = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"builtins", "eval.__call__", f"{file_path}/data/malicious13.pkl"
),
)
]
malicious13 = ModelScan()
malicious13.scan(Path(f"{file_path}/data/malicious13.pkl"))
assert malicious13.issues.all_issues == expected_malicious13


def test_scan_directory_path(file_path: str) -> None:
Expand Down Expand Up @@ -807,6 +881,20 @@ def test_scan_directory_path(file_path: str) -> None:
IssueSeverity.CRITICAL,
OperatorIssueDetails("os", "system", f"{file_path}/data/malicious11.pkl"),
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"pickle", "loads", f"{file_path}/data/malicious12.pkl"
),
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.CRITICAL,
OperatorIssueDetails(
"builtins", "eval.__call__", f"{file_path}/data/malicious13.pkl"
),
),
}
ms = ModelScan()
p = Path(f"{file_path}/data/")
Expand Down

0 comments on commit c093bc5

Please sign in to comment.