diff --git a/.gitignore b/.gitignore
index 7e48834..1de7eb8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -128,4 +128,8 @@ cython_debug/
 .DS_Store
 .vscode/
-.idea/
\ No newline at end of file
+.idea/
+
+# Notebook Model Downloads
+notebooks/PyTorchModels/
+pytorch-model-scan-results.json
\ No newline at end of file
diff --git a/modelscan/__init__.py b/modelscan/__init__.py
index b420341..9868bc5 100644
--- a/modelscan/__init__.py
+++ b/modelscan/__init__.py
@@ -1,4 +1,5 @@
 """CLI for scanning models"""
+
 import logging
 from modelscan._version import __version__
diff --git a/modelscan/cli.py b/modelscan/cli.py
index 230b083..a10fcb9 100644
--- a/modelscan/cli.py
+++ b/modelscan/cli.py
@@ -2,7 +2,7 @@
 import sys
 import os
 from pathlib import Path
-from typing import Optional, Dict, Any
+from typing import Optional
 
 from tomlkit import parse
 import click
@@ -87,7 +87,7 @@ def cli() -> None:
     help="Optional file name for output report",
 )
 @cli.command(
-    help="[Default] Scan a model file or diretory for ability to execute suspicious actions. "
+    help="[Default] Scan a model file or directory for ability to execute suspicious actions. "
 )  # type: ignore
 @click.pass_context
 def scan(
@@ -112,7 +112,7 @@ def scan(
     settings = DEFAULT_SETTINGS
 
     if settings_file_path and settings_file_path.is_file():
-        with open(settings_file_path) as sf:
+        with open(settings_file_path, encoding="utf-8") as sf:
             settings = parse(sf.read()).unwrap()
         click.echo(f"Detected settings file. Using {settings_file_path}. \n")
     else:
@@ -132,7 +132,7 @@ def scan(
         raise click.UsageError("Command line must include a path")
 
     # Report scan results
-    if reporting_format is not "custom":
+    if reporting_format != "custom":
         modelscan._settings["reporting"]["module"] = DEFAULT_REPORTING_MODULES[
             reporting_format
         ]
@@ -174,16 +174,17 @@ def create_settings(force: bool, location: Optional[str]) -> None:
         settings_path = location
 
     try:
-        open(settings_path)
+        open(settings_path, encoding="utf-8")
         if force:
-            with open(settings_path, "w") as settings_file:
+            with open(settings_path, mode="w", encoding="utf-8") as settings_file:
                 settings_file.write(SettingsUtils.get_default_settings_as_toml())
         else:
            logger.warning(
-                f"{settings_path} file already exists. Please use `--force` flag if you intend to overwrite it."
+                "%s file already exists. Please use `--force` flag if you intend to overwrite it.",
+                settings_path,
            )
    except FileNotFoundError:
-        with open(settings_path, "w") as settings_file:
+        with open(settings_path, mode="w", encoding="utf-8") as settings_file:
            settings_file.write(SettingsUtils.get_default_settings_as_toml())
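Note on the `reporting_format` fix above: it is more than style. `is` / `is not` compare object identity, not equality, so the old check only appeared to work because CPython happens to intern short string literals. A minimal illustration (names here are hypothetical, not from the patch):

```python
custom = "custom"
runtime_value = "".join(["cus", "tom"])  # equal to "custom", but a distinct object

assert runtime_value == custom       # equality: what the CLI check actually needs
assert runtime_value is not custom   # identity differs, so `is not "custom"` misroutes
```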
diff --git a/modelscan/issues.py b/modelscan/issues.py
index 2e8da87..130318d 100644
--- a/modelscan/issues.py
+++ b/modelscan/issues.py
@@ -85,7 +85,7 @@ def print(self) -> None:
         if self.code == IssueCode.UNSAFE_OPERATOR:
             issue_description = "Unsafe operator"
         else:
-            logger.error(f"No issue description for issue code ${self.code}")
+            logger.error("No issue description for issue code %s", self.code)
 
         print(f"\n{issue_description} found:")
         print(f"  - Severity: {self.severity.name}")
diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py
index f222069..92b08e4 100644
--- a/modelscan/modelscan.py
+++ b/modelscan/modelscan.py
@@ -70,7 +70,7 @@ def _load_scanners(self) -> None:
                 self._scanners_to_run.append(scanner_class)
 
         except Exception as e:
-            logger.error(f"Error importing scanner {scanner_path}")
+            logger.error("Error importing scanner %s", scanner_path)
             self._init_errors.append(
                 ModelScanError(
                     f"Error importing scanner {scanner_path}: {e}",
@@ -79,12 +79,12 @@ def _load_scanners(self) -> None:
 
     def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
         if not model_path.exists():
-            logger.error(f"Path {model_path} does not exist")
+            logger.error("Path %s does not exist", model_path)
             self._errors.append(PathError("Path is not valid", model_path))
 
         files = [model_path]
         if model_path.is_dir():
-            logger.debug(f"Path {str(model_path)} is a directory")
+            logger.debug("Path %s is a directory", str(model_path))
             files = [f for f in model_path.rglob("*") if Path.is_file(f)]
 
         for file in files:
@@ -116,8 +116,8 @@ def _iterate_models(self, model_path: Path) -> Generator[Model, None, None]:
                             yield Model(file_name, file_io)
                     except zipfile.BadZipFile as e:
                         logger.debug(
-                            f"Skipping zip file {str(model.get_source())}, due to error",
-                            e,
+                            "Skipping zip file %s, due to error",
+                            str(model.get_source()),
                             exc_info=True,
                         )
                         self._skipped.append(
@@ -178,7 +178,10 @@ def _scan_source(
             scan_results = scanner.scan(model)
         except Exception as e:
             logger.error(
-                f"Error encountered from scanner {scanner.full_name()} with path {str(model.get_source())}: {e}"
+                "Error encountered from scanner %s with path %s: %s",
+                scanner.full_name(),
+                str(model.get_source()),
+                e,
             )
             self._errors.append(
                 ModelScanScannerError(
@@ -192,7 +195,9 @@ def _scan_source(
         if scan_results is not None:
             scanned = True
             logger.info(
-                f"Scanning {model.get_source()} using {scanner.full_name()} model scan"
+                "Scanning %s using %s model scan",
+                model.get_source(),
+                scanner.full_name(),
             )
             if scan_results.errors:
                 self._errors.extend(scan_results.errors)
@@ -212,7 +217,7 @@ def _scan_source(
                     ModelScanSkipped(
                         "ModelScan",
                         SkipCategories.SCAN_NOT_SUPPORTED,
-                        f"Model Scan did not scan file",
+                        "Model Scan did not scan file",
                         str(model.get_source()),
                     )
                 )
@@ -328,7 +333,7 @@ def generate_report(self) -> Optional[str]:
                 scan_report = report_class.generate(scan=self, settings=report_settings)
 
         except Exception as e:
-            logger.error(f"Error generating report using {reporting_module}: {e}")
+            logger.error("Error generating report using %s: %s", reporting_module, e)
             self._errors.append(
                 ModelScanError(f"Error generating report using {reporting_module}: {e}")
             )
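The logging changes in `issues.py` and `modelscan.py` all follow one pattern: pass arguments to the logger instead of pre-formatting with f-strings, so interpolation is deferred until a handler actually emits the record, and use `exc_info=True` rather than passing the exception object as a stray positional argument. A small illustrative sketch (the path value is hypothetical):

```python
import logging

logger = logging.getLogger("modelscan")

path = "/tmp/model.pt"  # hypothetical path, for illustration only
logger.debug(f"Path {path} is a directory")   # formatted eagerly, even if DEBUG is off
logger.debug("Path %s is a directory", path)  # args stored on the record; formatted lazily

try:
    raise ValueError("boom")
except ValueError:
    # exc_info=True attaches the current traceback; the old code passed `e`
    # as a format argument with no matching placeholder in the message.
    logger.debug("Skipping zip file %s, due to error", path, exc_info=True)
```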
diff --git a/modelscan/reports.py b/modelscan/reports.py
index 9dacb0e..f09159b 100644
--- a/modelscan/reports.py
+++ b/modelscan/reports.py
@@ -32,7 +32,7 @@ def generate(
         :param errors: Any errors that occurred during the scan.
         """
-        raise NotImplemented
+        raise NotImplementedError
 
 
 class ConsoleReport(Report):
@@ -46,7 +46,7 @@ def generate(
         total_issue_count = len(scan.issues.all_issues)
         if total_issue_count > 0:
             print(f"\nTotal Issues: {total_issue_count}")
-            print(f"\nTotal Issues By Severity:\n")
+            print("\nTotal Issues By Severity:\n")
             for severity in IssueSeverity:
                 if severity.name in issues_by_severity:
                     print(
@@ -75,7 +75,7 @@ def generate(
                 f"\nTotal skipped: {len(scan.skipped)} - run with --show-skipped to see the full list."
             )
             if settings["show_skipped"]:
-                print(f"\nSkipped files list:\n")
+                print("\nSkipped files list:\n")
                 for file_name in scan.skipped:
                     print(str(file_name))
diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py
index f95ff8c..c398535 100644
--- a/modelscan/scanners/h5/scan.py
+++ b/modelscan/scanners/h5/scan.py
@@ -63,7 +63,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
                 [
                     JsonDecodeError(
                         self.name(),
-                        f"Not a valid JSON data",
+                        "Not a valid JSON data",
                         model,
                     )
                 ],
@@ -85,7 +85,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
                     ModelScanSkipped(
                         self.name(),
                         SkipCategories.MODEL_CONFIG,
-                        f"Model Config not found",
+                        "Model Config not found",
                         str(model.get_source()),
                     )
                 ],
@@ -104,7 +104,7 @@ def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:
         with h5py.File(model.get_stream()) as model_hdf5:
             try:
-                if not "model_config" in model_hdf5.attrs.keys():
+                if "model_config" not in model_hdf5.attrs.keys():
                     return None
 
                 model_config = json.loads(model_hdf5.attrs.get("model_config", {}))
diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py
index 50acd1a..2a7fb5e 100644
--- a/modelscan/scanners/keras/scan.py
+++ b/modelscan/scanners/keras/scan.py
@@ -65,7 +65,7 @@ def scan(self, model: Model) -> Optional[ScanResults]:
                 [
                     ModelScanScannerError(
                         self.name(),
-                        f"Unable to scan .keras file",  # Not sure if this is a representative message for ModelScanError
+                        "Unable to scan .keras file",  # Not sure if this is a representative message for ModelScanError
                         model,
                     )
                 ],
@@ -89,7 +89,7 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:
                 [
                     JsonDecodeError(
                         self.name(),
-                        f"Not a valid JSON data",
+                        "Not a valid JSON data",
                         model,
                     )
                 ],
diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py
index 58d1414..357b047 100644
--- a/modelscan/scanners/saved_model/scan.py
+++ b/modelscan/scanners/saved_model/scan.py
@@ -2,9 +2,8 @@
 import json
 import logging
-from pathlib import Path
 
-from typing import IO, List, Set, Union, Optional, Dict, Any
+from typing import List, Set, Optional, Dict, Any
 
 try:
     import tensorflow
@@ -123,7 +122,7 @@ def _scan(self, model: Model) -> Optional[ScanResults]:
                 [
                     JsonDecodeError(
                         self.name(),
-                        f"Not a valid JSON data",
+                        "Not a valid JSON data",
                         model,
                     )
                 ],
diff --git a/modelscan/skip.py b/modelscan/skip.py
index 571fff9..27272e6 100644
--- a/modelscan/skip.py
+++ b/modelscan/skip.py
@@ -1,10 +1,6 @@
-import abc
 import logging
 from enum import Enum
-from pathlib import Path
-from typing import Any, List, Union, Dict, Optional
-from collections import defaultdict
 
 logger = logging.getLogger("modelscan")
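The `reports.py` change fixes a genuine bug, not just a lint warning: `NotImplemented` is the sentinel value returned by binary-operator protocols, and `raise NotImplemented` fails with its own `TypeError` because it is not an exception. `NotImplementedError` is the exception intended for abstract stubs. A pared-down sketch of the pattern:

```python
class Report:
    @staticmethod
    def generate(scan, settings) -> str:
        # `raise NotImplemented` would die with
        # "TypeError: exceptions must derive from BaseException".
        raise NotImplementedError


class ConsoleReport(Report):
    @staticmethod
    def generate(scan, settings) -> str:
        return "ok"  # subclasses override the stub
```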
diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py
index 5c1350b..8cae3b9 100644
--- a/modelscan/tools/picklescanner.py
+++ b/modelscan/tools/picklescanner.py
@@ -259,7 +259,7 @@ def scan_pytorch(model: Model, settings: Dict[str, Any]) -> ScanResults:
             ModelScanSkipped(
                 scan_name,
                 SkipCategories.MAGIC_NUMBER,
-                f"Invalid magic number",
+                "Invalid magic number",
                 str(model.get_source()),
             )
         ],
diff --git a/modelscan/tools/utils.py b/modelscan/tools/utils.py
index 58c986c..e309085 100644
--- a/modelscan/tools/utils.py
+++ b/modelscan/tools/utils.py
@@ -1,6 +1,4 @@
-import http.client
 import io
-import urllib.parse
 from pathlib import Path
 from pickletools import genops  # nosec
 from typing import IO, Optional, Union
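For orientation, the scanners and reports patched above are driven through the `ModelScan` API exercised by the notebook and tests later in this diff. A minimal usage sketch, assuming `ModelScan` is importable from `modelscan.modelscan` as the module layout suggests; the result keys mirror the assertions in `tests/test_modelscan.py`:

```python
from pathlib import Path

from modelscan.modelscan import ModelScan  # assumed import path

ms = ModelScan()
results = ms.scan(Path("PyTorchModels/unsafe_model.pt"))  # hypothetical path

print(results["summary"]["scanned"]["scanned_files"])
print(results["summary"]["skipped"]["skipped_files"])
for issue in ms.issues.all_issues:
    issue.print()  # the console helper patched in issues.py above
```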
diff --git a/notebooks/pytorch_sentiment_analysis.ipynb b/notebooks/pytorch_sentiment_analysis.ipynb
index 282439b..f30f15a 100644
--- a/notebooks/pytorch_sentiment_analysis.ipynb
+++ b/notebooks/pytorch_sentiment_analysis.ipynb
@@ -23,12 +23,13 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "modelscan, version 0.5.0\n"
+      "Note: you may need to restart the kernel to use updated packages.\n",
+      "modelscan, version 0.0.0\n"
      ]
     }
    ],
    "source": [
-    "!pip install -q modelscan\n",
+    "%pip install -q modelscan\n",
     "!modelscan -v"
    ]
   },
@@ -36,26 +37,25 @@
    "cell_type": "code",
    "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+     ]
+    }
+   ],
    "source": [
-    "!pip install -q torch==2.0.1\n",
-    "!pip install -q transformers==4.31.0\n",
-    "!pip install -q scipy==1.11.1"
+    "%pip install -q torch==2.0.1\n",
+    "%pip install -q transformers==4.31.0\n",
+    "%pip install -q scipy==1.11.1"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "metadata": {},
    "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/Users/mehrinkiani/mambaforge/envs/py310/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
-      "  from .autonotebook import tqdm as notebook_tqdm\n"
-     ]
-    },
     {
      "name": "stdout",
      "output_type": "stream",
@@ -66,10 +66,11 @@
    ],
    "source": [
     "import torch\n",
-    "import os \n",
+    "import os\n",
     "from utils.pytorch_sentiment_model import download_model, predict_sentiment\n",
     "from utils.pickle_codeinjection import PickleInject, get_payload\n",
-    "%env TOKENIZERS_PARALLELISM=false\n"
+    "\n",
+    "%env TOKENIZERS_PARALLELISM=false"
    ]
   },
@@ -84,18 +85,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 5,
    "metadata": {},
    "outputs": [],
    "source": [
     "# Save a model for sentiment analysis\n",
-    "model_directory = \"PyTorchModels\"\n",
+    "from typing import Final\n",
+    "\n",
+    "model_directory: Final[str] = \"PyTorchModels\"\n",
     "if not os.path.isdir(model_directory):\n",
     "    os.mkdir(model_directory)\n",
     "\n",
     "safe_model_path = os.path.join(model_directory, \"safe_model.pt\")\n",
     "\n",
-    "sentiment_model = download_model(safe_model_path)"
+    "download_model(safe_model_path)"
    ]
   },
@@ -107,7 +110,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
@@ -119,7 +122,9 @@
     }
    ],
    "source": [
-    "sentiment = predict_sentiment(\"Stock market was bearish today\", torch.load(safe_model_path))"
+    "sentiment = predict_sentiment(\n",
+    "    \"Stock market was bearish today\", torch.load(safe_model_path)\n",
+    ")"
    ]
   },
@@ -133,20 +138,20 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. \n",
-      "\n",
-      "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/safe_model.pt:safe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n",
-      "\n",
-      "\u001b[34m--- Summary ---\u001b[0m\n",
-      "\n",
-      "\u001b[32m No issues found! 🎉\u001b[0m\n"
+      "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. \n",
+      "\n",
+      "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/safe_model.pt:safe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n",
+      "\n",
+      "\u001b[34m--- Summary ---\u001b[0m\n",
+      "\n",
+      "\u001b[32m No issues found! 🎉\u001b[0m\n"
      ]
     }
    ],
@@ -165,7 +170,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 8,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -175,7 +180,6 @@
     "\n",
     "unsafe_model_path = os.path.join(model_directory, \"unsafe_model.pt\")\n",
     "\n",
-    "\n",
     "payload = get_payload(command, malicious_code)\n",
     "torch.save(\n",
     "    torch.load(safe_model_path),\n",
@@ -197,11 +201,11 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
     {
-     "name": "stdout",
+     "name": "stderr",
      "output_type": "stream",
      "text": [
       "aws_access_key_id=\n",
@@ -211,7 +215,7 @@
     }
    ],
    "source": [
-    "sentiment = predict_sentiment(\"Stock market was bearish today\", torch.load(unsafe_model_path))"
+    "predict_sentiment(\"Stock market was bearish today\", torch.load(unsafe_model_path))"
    ]
   },
@@ -227,41 +231,41 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 10,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. \n",
-      "\n",
-      "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n",
-      "\n",
-      "\u001b[34m--- Summary ---\u001b[0m\n",
-      "\n",
-      "Total Issues: \u001b[1;36m1\u001b[0m\n",
-      "\n",
-      "Total Issues By Severity:\n",
-      "\n",
-      "  - LOW: \u001b[1;32m0\u001b[0m\n",
-      "  - MEDIUM: \u001b[1;32m0\u001b[0m\n",
-      "  - HIGH: \u001b[1;32m0\u001b[0m\n",
-      "  - CRITICAL: \u001b[1;36m1\u001b[0m\n",
-      "\n",
-      "\u001b[34m--- Issues by Severity ---\u001b[0m\n",
-      "\n",
-      "\u001b[34m--- CRITICAL ---\u001b[0m\n",
-      "\n",
-      "Unsafe operator found:\n",
-      "  - Severity: CRITICAL\n",
-      "  - Description: Use of unsafe operator 'system' from module 'posix'\n",
-      "  - Source: /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl\n"
+      "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. \n",
+      "\n",
+      "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n",
+      "\n",
+      "\u001b[34m--- Summary ---\u001b[0m\n",
+      "\n",
+      "Total Issues: \u001b[1;36m1\u001b[0m\n",
+      "\n",
+      "Total Issues By Severity:\n",
+      "\n",
+      "  - LOW: \u001b[1;32m0\u001b[0m\n",
+      "  - MEDIUM: \u001b[1;32m0\u001b[0m\n",
+      "  - HIGH: \u001b[1;32m0\u001b[0m\n",
+      "  - CRITICAL: \u001b[1;36m1\u001b[0m\n",
+      "\n",
+      "\u001b[34m--- Issues by Severity ---\u001b[0m\n",
+      "\n",
+      "\u001b[34m--- CRITICAL ---\u001b[0m\n",
+      "\n",
+      "Unsafe operator found:\n",
+      "  - Severity: CRITICAL\n",
+      "  - Description: Use of unsafe operator 'system' from module 'posix'\n",
+      "  - Source: /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl\n"
      ]
     }
    ],
    "source": [
-    "!modelscan --path ./PyTorchModels/unsafe_model.pt "
+    "!modelscan --path ./PyTorchModels/unsafe_model.pt"
    ]
   },
\n", - "\n", - "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", - "\u001b[1m{\u001b[0m\u001b[32m\"modelscan_version\"\u001b[0m: \u001b[32m\"0.5.0\"\u001b[0m, \u001b[32m\"timestamp\"\u001b[0m: \u001b[32m\"2024-01-25T17:10:54.306065\"\u001b[0m, \n", - "\u001b[32m\"input_path\"\u001b[0m: \n", - "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt\"\u001b[0m\n", - ", \u001b[32m\"total_issues\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"summary\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"LOW\"\u001b[0m: \u001b[1;36m0\u001b[0m, \n", - "\u001b[32m\"MEDIUM\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"HIGH\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"CRITICAL\"\u001b[0m: \u001b[1;36m1\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m, \u001b[32m\"issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"CRITICAL\"\u001b[0m: \n", - "\u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m\"description\"\u001b[0m: \u001b[32m\"Use of unsafe operator 'system' from module 'posix'\"\u001b[0m, \n", - "\u001b[32m\"operator\"\u001b[0m: \u001b[32m\"system\"\u001b[0m, \u001b[32m\"module\"\u001b[0m: \u001b[32m\"posix\"\u001b[0m, \u001b[32m\"source\"\u001b[0m: \n", - "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:\u001b[0m\n", - "\u001b[32munsafe_model/data.pkl\"\u001b[0m, \u001b[32m\"scanner\"\u001b[0m: \u001b[32m\"modelscan.scanners.PickleUnsafeOpScan\"\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m, \n", - "\u001b[32m\"errors\"\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m, \u001b[32m\"scanned\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_scanned\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"scanned_files\"\u001b[0m: \n", - "\u001b[1m[\u001b[0m\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt\u001b[0m\n", - "\u001b[32m:unsafe_model/data.pkl\"\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n" + "No settings file detected at /Users/mehrinkiani/Documents/modelscan/notebooks/modelscan-settings.toml. Using defaults. 
\n", + "\n", + "Scanning /Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:unsafe_model/data.pkl using modelscan.scanners.PickleUnsafeOpScan model scan\n", + "\u001b[1m{\u001b[0m\u001b[32m\"modelscan_version\"\u001b[0m: \u001b[32m\"0.5.0\"\u001b[0m, \u001b[32m\"timestamp\"\u001b[0m: \u001b[32m\"2024-01-25T17:10:54.306065\"\u001b[0m, \n", + "\u001b[32m\"input_path\"\u001b[0m: \n", + "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt\"\u001b[0m\n", + ", \u001b[32m\"total_issues\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"summary\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"LOW\"\u001b[0m: \u001b[1;36m0\u001b[0m, \n", + "\u001b[32m\"MEDIUM\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"HIGH\"\u001b[0m: \u001b[1;36m0\u001b[0m, \u001b[32m\"CRITICAL\"\u001b[0m: \u001b[1;36m1\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m, \u001b[32m\"issues_by_severity\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"CRITICAL\"\u001b[0m: \n", + "\u001b[1m[\u001b[0m\u001b[1m{\u001b[0m\u001b[32m\"description\"\u001b[0m: \u001b[32m\"Use of unsafe operator 'system' from module 'posix'\"\u001b[0m, \n", + "\u001b[32m\"operator\"\u001b[0m: \u001b[32m\"system\"\u001b[0m, \u001b[32m\"module\"\u001b[0m: \u001b[32m\"posix\"\u001b[0m, \u001b[32m\"source\"\u001b[0m: \n", + "\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt:\u001b[0m\n", + "\u001b[32munsafe_model/data.pkl\"\u001b[0m, \u001b[32m\"scanner\"\u001b[0m: \u001b[32m\"modelscan.scanners.PickleUnsafeOpScan\"\u001b[0m\u001b[1m}\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m, \n", + "\u001b[32m\"errors\"\u001b[0m: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m, \u001b[32m\"scanned\"\u001b[0m: \u001b[1m{\u001b[0m\u001b[32m\"total_scanned\"\u001b[0m: \u001b[1;36m1\u001b[0m, \u001b[32m\"scanned_files\"\u001b[0m: \n", + "\u001b[1m[\u001b[0m\u001b[32m\"/Users/mehrinkiani/Documents/modelscan/notebooks/PyTorchModels/unsafe_model.pt\u001b[0m\n", + "\u001b[32m:unsafe_model/data.pkl\"\u001b[0m\u001b[1m]\u001b[0m\u001b[1m}\u001b[0m\u001b[1m}\u001b[0m\n" ] } ], @@ -312,13 +316,6 @@ "# This will save the scan results in file: pytorch-model-scan-results.json\n", "!modelscan --path ./PyTorchModels/unsafe_model.pt -r json -o pytorch-model-scan-results.json" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -337,7 +334,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.10.14" }, "vscode": { "interpreter": { diff --git a/notebooks/utils/pickle_codeinjection.py b/notebooks/utils/pickle_codeinjection.py index f592922..a609c73 100644 --- a/notebooks/utils/pickle_codeinjection.py +++ b/notebooks/utils/pickle_codeinjection.py @@ -1,6 +1,8 @@ +from __future__ import annotations + +import os import pickle import struct -import os class PickleInject: @@ -88,7 +90,22 @@ def __reduce__(self): return self.command, (self.args, {}) -def get_payload(command: str, malicious_code: str): +def get_payload( + command: str, malicious_code: str +) -> PickleInject.System | PickleInject.Exec | PickleInject.Eval | PickleInject.RunPy: + """ + Get the payload based on the command and malicious code provided. + + Args: + command: The command to execute. + malicious_code: The malicious code to inject. + + Returns: + The payload object based on the command. 
diff --git a/notebooks/utils/pickle_codeinjection.py b/notebooks/utils/pickle_codeinjection.py
index f592922..a609c73 100644
--- a/notebooks/utils/pickle_codeinjection.py
+++ b/notebooks/utils/pickle_codeinjection.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
+import os
 import pickle
 import struct
-import os
 
 
 class PickleInject:
@@ -88,7 +90,22 @@ def __reduce__(self):
         return self.command, (self.args, {})
 
 
-def get_payload(command: str, malicious_code: str):
+def get_payload(
+    command: str, malicious_code: str
+) -> PickleInject.System | PickleInject.Exec | PickleInject.Eval | PickleInject.RunPy:
+    """
+    Get the payload based on the command and malicious code provided.
+
+    Args:
+        command: The command to execute.
+        malicious_code: The malicious code to inject.
+
+    Returns:
+        The payload object based on the command.
+
+    Raises:
+        ValueError: If an invalid command is provided.
+    """
     if command == "system":
         payload = PickleInject.System(malicious_code)
     elif command == "exec":
@@ -97,6 +114,9 @@ def get_payload(command: str, malicious_code: str):
         payload = PickleInject.Eval(malicious_code)
     elif command == "runpy":
         payload = PickleInject.RunPy(malicious_code)
+    else:
+        raise ValueError("Invalid command provided.")
+
     return payload
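For readers new to the mechanism `PickleInject` relies on: unpickling calls whatever callable `__reduce__` returned, with the returned arguments. A self-contained, harmless illustration (independent of the helper above, not part of the patch):

```python
import pickle


class Payload:
    def __reduce__(self):
        # (callable, args): the unpickler invokes this pair during load.
        return (print, ("side effect runs at pickle.loads time",))


blob = pickle.dumps(Payload())
pickle.loads(blob)  # prints, i.e. executes the callable, on load
```

Swap `print` for `os.system` and the loader runs a shell command, which is what the notebook's `get_payload("system", ...)` path injects into the PyTorch zip.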
diff --git a/notebooks/utils/pytorch_sentiment_model.py b/notebooks/utils/pytorch_sentiment_model.py
index bf668a4..25278df 100644
--- a/notebooks/utils/pytorch_sentiment_model.py
+++ b/notebooks/utils/pytorch_sentiment_model.py
@@ -1,17 +1,27 @@
+from typing import Any, Final
 from transformers import AutoModelForSequenceClassification
-from transformers import TFAutoModelForSequenceClassification
 from transformers import AutoTokenizer
 import numpy as np
 from scipy.special import softmax
 import csv
 import urllib.request
-import os
 import torch
 
+SENTIMENT_TASK: Final[str] = "sentiment"
 
-# Preprocess text (username and link placeholders)
-def preprocess(text):
-    new_text = []
+
+def _preprocess(text: str) -> str:
+    """
+    Preprocess the given text by replacing usernames starting with '@' with '@user'
+    and replacing URLs starting with 'http' with 'http'.
+
+    Args:
+        text: The input text to be preprocessed.
+
+    Returns:
+        The preprocessed text.
+    """
+    new_text: list[str] = []
 
     for t in text.split(" "):
         t = "@user" if t.startswith("@") and len(t) > 1 else t
@@ -20,27 +30,37 @@ def preprocess(text):
     return " ".join(new_text)
 
 
-def download_model(safe_model_path):
-    task = "sentiment"
-    MODEL = f"cardiffnlp/twitter-roberta-base-{task}"
-    # PT
-    model = AutoModelForSequenceClassification.from_pretrained(MODEL)
+def download_model(safe_model_path: str) -> None:
+    """
+    Download a pre-trained model and save it to the specified path.
+
+    Args:
+        safe_model_path: The path where the model will be saved.
+    """
+    pretrained_model_name = f"cardiffnlp/twitter-roberta-base-{SENTIMENT_TASK}"
+    model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name)
     torch.save(model, safe_model_path)
 
 
-def predict_sentiment(text: str, model):
-    task = "sentiment"
-    MODEL = f"cardiffnlp/twitter-roberta-base-sentiment"
-    tokenizer = AutoTokenizer.from_pretrained(MODEL)
+def predict_sentiment(text: str, model: Any) -> None:
+    """
+    Predict the sentiment of a given text using a pre-trained sentiment analysis model.
+
+    Args:
+        text: The input text to analyze.
+        model: The sentiment analysis model.
+    """
+    pretrained_model_name = "cardiffnlp/twitter-roberta-base-sentiment"
+    tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
 
-    text = preprocess(text)
+    text = _preprocess(text)
     encoded_input = tokenizer(text, return_tensors="pt")
     output = model(**encoded_input)
     scores = output[0][0].detach().numpy()
     scores = softmax(scores)
 
-    labels = []
-    mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{task}/mapping.txt"
+    labels: list[str] = []
+    mapping_link = f"https://raw.githubusercontent.com/cardiffnlp/tweeteval/main/datasets/{SENTIMENT_TASK}/mapping.txt"
     with urllib.request.urlopen(mapping_link) as f:
         html = f.read().decode("utf-8").split("\n")
         csvreader = csv.reader(html, delimiter="\t")
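The behaviour the new `_preprocess` docstring describes, as a usage example (derived from the code above, not from the patch's tests):

```python
from utils.pytorch_sentiment_model import _preprocess

# '@' mentions collapse to '@user'; tokens starting with 'http' collapse to 'http'.
assert _preprocess("@alice says http://example.com is up") == "@user says http is up"
```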
diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py
index 72b69bf..4355798 100644
--- a/tests/test_modelscan.py
+++ b/tests/test_modelscan.py
@@ -483,7 +483,7 @@ def test_scan_zip(zip_file_path: Any) -> None:
         ms = ModelScan()
         results = ms.scan(f"{zip_file_path}/test.zip")
-        assert results["summary"]["scanned"]["scanned_files"] == [f"test.zip:data.pkl"]
+        assert results["summary"]["scanned"]["scanned_files"] == ["test.zip:data.pkl"]
         assert results["summary"]["skipped"]["skipped_files"] == []
         assert ms.issues.all_issues == expected
@@ -495,15 +495,15 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
     assert results["summary"]["skipped"]["skipped_files"] == [
         {
             "category": SkipCategories.MAGIC_NUMBER.name,
-            "description": f"Invalid magic number",
-            "source": f"bad_pytorch.pt",
+            "description": "Invalid magic number",
+            "source": "bad_pytorch.pt",
         }
     ]
     assert ms.issues.all_issues == []
 
     results = ms.scan(Path(f"{pytorch_file_path}/safe_zip_pytorch.pt"))
     assert results["summary"]["scanned"]["scanned_files"] == [
-        f"safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl"
+        "safe_zip_pytorch.pt:safe_zip_pytorch/data.pkl"
     ]
 
     assert set(
@@ -521,7 +521,7 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
 
     results = ms.scan(Path(f"{pytorch_file_path}/safe_old_format_pytorch.pt"))
     assert results["summary"]["scanned"]["scanned_files"] == [
-        f"safe_old_format_pytorch.pt"
+        "safe_old_format_pytorch.pt"
     ]
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert ms.issues.all_issues == []
@@ -542,7 +542,7 @@ def test_scan_pytorch(pytorch_file_path: Any) -> None:
     ]
     results = ms.scan(unsafe_zip_path)
     assert results["summary"]["scanned"]["scanned_files"] == [
-        f"unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl",
+        "unsafe_zip_pytorch.pt:unsafe_zip_pytorch/data.pkl",
     ]
     assert set(
         [
@@ -562,7 +562,7 @@ def test_scan_numpy(numpy_file_path: Any) -> None:
     ms = ModelScan()
     results = ms.scan(f"{numpy_file_path}/safe_numpy.npy")
     assert ms.issues.all_issues == []
-    assert results["summary"]["scanned"]["scanned_files"] == [f"safe_numpy.npy"]
+    assert results["summary"]["scanned"]["scanned_files"] == ["safe_numpy.npy"]
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert results["errors"] == []
@@ -581,7 +581,7 @@ def test_scan_numpy(numpy_file_path: Any) -> None:
     results = ms.scan(f"{numpy_file_path}/unsafe_numpy.npy")
 
     compare_results(ms.issues.all_issues, expected)
-    assert results["summary"]["scanned"]["scanned_files"] == [f"unsafe_numpy.npy"]
+    assert results["summary"]["scanned"]["scanned_files"] == ["unsafe_numpy.npy"]
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert results["errors"] == []
@@ -590,14 +590,14 @@ def test_scan_file_path(file_path: Any) -> None:
     benign_pickle = ModelScan()
     results = benign_pickle.scan(Path(f"{file_path}/data/benign0_v3.pkl"))
     assert benign_pickle.issues.all_issues == []
-    assert results["summary"]["scanned"]["scanned_files"] == [f"benign0_v3.pkl"]
+    assert results["summary"]["scanned"]["scanned_files"] == ["benign0_v3.pkl"]
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert results["errors"] == []
 
     benign_dill = ModelScan()
     results = benign_dill.scan(Path(f"{file_path}/data/benign0_v3.dill"))
     assert benign_dill.issues.all_issues == []
-    assert results["summary"]["scanned"]["scanned_files"] == [f"benign0_v3.dill"]
+    assert results["summary"]["scanned"]["scanned_files"] == ["benign0_v3.dill"]
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert results["errors"] == []
@@ -646,7 +646,7 @@ def test_scan_file_path(file_path: Any) -> None:
     }
     results = malicious0.scan(Path(f"{file_path}/data/malicious0.pkl"))
     compare_results(malicious0.issues.all_issues, expected_malicious0)
-    assert results["summary"]["scanned"]["scanned_files"] == [f"malicious0.pkl"]
+    assert results["summary"]["scanned"]["scanned_files"] == ["malicious0.pkl"]
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert results["errors"] == []
@@ -1268,35 +1268,35 @@ def test_scan_directory_path(file_path: str) -> None:
     results = ms.scan(p)
     compare_results(ms.issues.all_issues, expected)
     assert set(results["summary"]["scanned"]["scanned_files"]) == {
-        f"malicious1.zip:data.pkl",
-        f"malicious0.pkl",
-        f"malicious3.pkl",
-        f"malicious6.pkl",
-        f"malicious7.pkl",
-        f"malicious8.pkl",
-        f"malicious9.pkl",
-        f"malicious10.pkl",
-        f"malicious11.pkl",
-        f"malicious12.pkl",
-        f"malicious13.pkl",
-        f"malicious14.pkl",
-        f"malicious1_v0.dill",
-        f"malicious1_v3.dill",
-        f"malicious1_v4.dill",
-        f"malicious4.pickle",
-        f"malicious5.pickle",
-        f"malicious1_v0.pkl",
-        f"malicious1_v3.pkl",
-        f"malicious1_v4.pkl",
-        f"malicious2_v0.pkl",
-        f"malicious2_v3.pkl",
-        f"malicious2_v4.pkl",
-        f"benign0_v0.pkl",
-        f"benign0_v3.pkl",
-        f"benign0_v4.pkl",
-        f"benign0_v0.dill",
-        f"benign0_v3.dill",
-        f"benign0_v4.dill",
+        "malicious1.zip:data.pkl",
+        "malicious0.pkl",
+        "malicious3.pkl",
+        "malicious6.pkl",
+        "malicious7.pkl",
+        "malicious8.pkl",
+        "malicious9.pkl",
+        "malicious10.pkl",
+        "malicious11.pkl",
+        "malicious12.pkl",
+        "malicious13.pkl",
+        "malicious14.pkl",
+        "malicious1_v0.dill",
+        "malicious1_v3.dill",
+        "malicious1_v4.dill",
+        "malicious4.pickle",
+        "malicious5.pickle",
+        "malicious1_v0.pkl",
+        "malicious1_v3.pkl",
+        "malicious1_v4.pkl",
+        "malicious2_v0.pkl",
+        "malicious2_v3.pkl",
+        "malicious2_v4.pkl",
+        "benign0_v0.pkl",
+        "benign0_v3.pkl",
+        "benign0_v4.pkl",
+        "benign0_v0.dill",
+        "benign0_v3.dill",
+        "benign0_v4.dill",
     }
     assert results["summary"]["skipped"]["skipped_files"] == []
     assert results["errors"] == []
@@ -1325,9 +1325,9 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
 
     if file_extension == ".pb":
         assert set(results["summary"]["scanned"]["scanned_files"]) == {
-            f"fingerprint.pb",
-            f"keras_metadata.pb",
-            f"saved_model.pb",
+            "fingerprint.pb",
+            "keras_metadata.pb",
+            "saved_model.pb",
         }
 
         assert set(
@@ -1336,8 +1336,8 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
             [
                 skipped_file["source"]
                 for skipped_file in results["summary"]["skipped"]["skipped_files"]
             ]
         ) == {
-            f"variables/variables.data-00000-of-00001",
-            f"variables/variables.index",
+            "variables/variables.data-00000-of-00001",
+            "variables/variables.index",
         }
         assert results["errors"] == []
@@ -1420,9 +1420,9 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
         assert ms.issues.all_issues == expected
         assert results["errors"] == []
         assert set(results["summary"]["scanned"]["scanned_files"]) == {
-            f"fingerprint.pb",
-            f"keras_metadata.pb",
-            f"saved_model.pb",
+            "fingerprint.pb",
+            "keras_metadata.pb",
+            "saved_model.pb",
         }
         assert set(
             [
@@ -1430,8 +1430,8 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
                 skipped_file["source"]
                 for skipped_file in results["summary"]["skipped"]["skipped_files"]
             ]
         ) == {
-            f"variables/variables.data-00000-of-00001",
-            f"variables/variables.index",
+            "variables/variables.data-00000-of-00001",
+            "variables/variables.index",
         }
     else:
         unsafe_filename = f"{keras_file_path_parent_dir}/unsafe{file_extension}"
@@ -1484,9 +1484,9 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
     results = ms.scan(Path(f"{safe_tensorflow_model_dir}"))
     assert ms.issues.all_issues == []
     assert set(results["summary"]["scanned"]["scanned_files"]) == {
-        f"fingerprint.pb",
-        f"keras_metadata.pb",
-        f"saved_model.pb",
+        "fingerprint.pb",
+        "keras_metadata.pb",
+        "saved_model.pb",
     }
     assert set(
         [
@@ -1494,8 +1494,8 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
            skipped_file["source"]
            for skipped_file in results["summary"]["skipped"]["skipped_files"]
        ]
    ) == {
-        f"variables/variables.data-00000-of-00001",
-        f"variables/variables.index",
+        "variables/variables.data-00000-of-00001",
+        "variables/variables.index",
     }
     assert results["errors"] == []
@@ -1526,9 +1526,9 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
 
     assert ms.issues.all_issues == expected
     assert set(results["summary"]["scanned"]["scanned_files"]) == {
-        f"fingerprint.pb",
-        f"keras_metadata.pb",
-        f"saved_model.pb",
+        "fingerprint.pb",
+        "keras_metadata.pb",
+        "saved_model.pb",
     }
     assert set(
         [
@@ -1536,8 +1536,8 @@ def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
            skipped_file["source"]
            for skipped_file in results["summary"]["skipped"]["skipped_files"]
        ]
    ) == {
-        f"variables/variables.data-00000-of-00001",
-        f"variables/variables.index",
+        "variables/variables.data-00000-of-00001",
+        "variables/variables.index",
     }
     assert results["errors"] == []
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6fedb5e..52a195c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,7 +3,6 @@
 import pickle
 import struct
 from typing import Any, Tuple
-import os
 import torch
 import torch.nn as nn
 import tensorflow as tf
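Most of the test churn above is one mechanical fix: dropping `f` prefixes from string literals that contain no placeholders (what pyflakes/flake8 reports as F541). The prefix is harmless at runtime but usually signals a forgotten placeholder; a tiny illustration with a hypothetical value:

```python
name = "unsafe_model.pt"  # hypothetical value, for illustration only
print(f"Scanning file")    # f-string with no placeholder: the f does nothing
print(f"Scanning {name}")  # probably what such a line intended
print("Scanning file")     # what the cleanup above normalizes to
```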