diff --git a/modelscan/cli.py b/modelscan/cli.py index bd61d51..8f102d0 100644 --- a/modelscan/cli.py +++ b/modelscan/cli.py @@ -5,7 +5,7 @@ import click -from modelscan.modelscan import Modelscan +from modelscan.modelscan import ModelScan from modelscan.reports import ConsoleReport from modelscan._version import __version__ @@ -61,14 +61,13 @@ def cli( if log is not None: logger.setLevel(getattr(logging, log)) - modelscan = Modelscan() + modelscan = ModelScan() if path is not None: pathlibPath = Path().cwd() if path == "." else Path(path).absolute() if not pathlibPath.exists(): raise FileNotFoundError(f"Path {path} does not exist") else: - modelscan.scan_path(pathlibPath) - + modelscan.scan(pathlibPath) else: raise click.UsageError("Command line must include a path") ConsoleReport.generate( diff --git a/modelscan/issues.py b/modelscan/issues.py index 27ee1c2..e9d64d1 100644 --- a/modelscan/issues.py +++ b/modelscan/issues.py @@ -21,9 +21,16 @@ class IssueCode(Enum): class IssueDetails(metaclass=abc.ABCMeta): + def __init__(self, scanner: str = "") -> None: + self.scanner = scanner + @abc.abstractmethod def output_lines(self) -> List[str]: - raise NotImplemented + raise NotImplementedError + + @abc.abstractmethod + def output_json(self) -> Dict[str, str]: + raise NotImplementedError class Issue: @@ -110,13 +117,25 @@ def group_by_severity(self) -> Dict[str, List[Issue]]: class OperatorIssueDetails(IssueDetails): - def __init__(self, module: str, operator: str, source: Union[Path, str]) -> None: + def __init__( + self, module: str, operator: str, source: Union[Path, str], scanner: str = "" + ) -> None: self.module = module self.operator = operator self.source = source + self.scanner = scanner def output_lines(self) -> List[str]: return [ f"Description: Use of unsafe operator '{self.operator}' from module '{self.module}'", f"Source: {str(self.source)}", ] + + def output_json(self) -> Dict[str, str]: + return { + "description": f"Use of unsafe operator '{self.operator}' from module '{self.module}'", + "operator": f"{self.operator}", + "module": f"{self.module}", + "source": f"{str(self.source)}", + "scanner": f"{self.scanner}", + } diff --git a/modelscan/models/__init__.py b/modelscan/models/__init__.py deleted file mode 100644 index c0021d8..0000000 --- a/modelscan/models/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from modelscan.models.h5.scan import H5Scan -from modelscan.models.pickle.scan import ( - PickleScan, - NumpyScan, - PyTorchScan, -) -from modelscan.models.saved_model.scan import SavedModelScan -from modelscan.models.keras.scan import KerasScan diff --git a/modelscan/models/pickle/scan.py b/modelscan/models/pickle/scan.py deleted file mode 100644 index de8579b..0000000 --- a/modelscan/models/pickle/scan.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -from pathlib import Path -from typing import IO, List, Tuple, Union, Optional - -from modelscan.error import Error -from modelscan.issues import Issue -from modelscan.models.scan import ScanBase -from modelscan.tools.picklescanner import ( - scan_numpy, - scan_pickle_bytes, - scan_pytorch, -) - -logger = logging.getLogger("modelscan") - - -class PyTorchScan(ScanBase): - @staticmethod - def scan( - source: Union[str, Path], - data: Optional[IO[bytes]] = None, - ) -> Tuple[List[Issue], List[Error]]: - if data: - return scan_pytorch(data=data, source=source) - - with open(source, "rb") as file_io: - return scan_pytorch(data=file_io, source=source) - - @staticmethod - def supported_extensions() -> List[str]: - return [".bin", ".pt", ".pth", ".ckpt"] - - @staticmethod - def name() -> str: - return "pytorch" - - -class NumpyScan(ScanBase): - @staticmethod - def scan( - source: Union[str, Path], - data: Optional[IO[bytes]] = None, - ) -> Tuple[List[Issue], List[Error]]: - if data: - return scan_numpy(data=data, source=source) - - with open(source, "rb") as file_io: - return scan_numpy(data=file_io, source=source) - - @staticmethod - def supported_extensions() -> List[str]: - return [".npy"] - - @staticmethod - def name() -> str: - return "numpy" - - -class PickleScan(ScanBase): - @staticmethod - def scan( - source: Union[str, Path], - data: Optional[IO[bytes]] = None, - ) -> Tuple[List[Issue], List[Error]]: - if data: - return scan_pickle_bytes(data=data, source=source) - - with open(source, "rb") as file_io: - return scan_pickle_bytes(data=file_io, source=source) - - @staticmethod - def supported_extensions() -> List[str]: - return [".pkl", ".pickle", ".joblib", ".dill", ".dat", ".data"] - - @staticmethod - def name() -> str: - return "pickle" diff --git a/modelscan/models/scan.py b/modelscan/models/scan.py deleted file mode 100644 index 2eade34..0000000 --- a/modelscan/models/scan.py +++ /dev/null @@ -1,25 +0,0 @@ -import abc -from pathlib import Path -from typing import List, Tuple, Union, Optional, IO - -from modelscan.error import Error -from modelscan.issues import Issue - - -class ScanBase(metaclass=abc.ABCMeta): - @staticmethod - @abc.abstractmethod - def name() -> str: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def scan( - source: Union[str, Path], data: Optional[IO[bytes]] = None - ) -> Tuple[List[Issue], List[Error]]: - raise NotImplementedError - - @staticmethod - @abc.abstractmethod - def supported_extensions() -> List[str]: - raise NotImplementedError diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 9da9a9f..f459119 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -1,88 +1,131 @@ -import io -import json import logging -import os import zipfile -import inspect +import importlib + +from modelscan.settings import DEFAULT_SCANNERS, DEFAULT_SETTINGS from pathlib import Path -from typing import List, Union, Optional, IO +from typing import List, Union, Optional, IO, Dict, Tuple, Any +from datetime import datetime -from modelscan.error import Error -from modelscan.issues import Issues, Issue -from modelscan import models -from modelscan.models.keras.scan import KerasScan -from modelscan.models.scan import ScanBase +from modelscan.error import Error, ModelScanError +from modelscan.issues import Issues, IssueSeverity +from modelscan.scanners.scan import ScanBase from modelscan.tools.utils import _is_zipfile - +from modelscan._version import __version__ logger = logging.getLogger("modelscan") -class Modelscan: - def __init__(self) -> None: - # Scans - - self.supported_model_scans = [ - member - for _, member in inspect.getmembers(models) - if inspect.isclass(member) - and issubclass(member, ScanBase) - and not inspect.isabstract(member) - ] - - self.supported_extensions = set() - for scan in self.supported_model_scans: - self.supported_extensions.update(scan.supported_extensions()) - self.supported_zip_extensions = set([".zip", ".npz"]) - logger.debug(f"Supported model files {self.supported_extensions}") - logger.debug(f"Supported zip model files {self.supported_zip_extensions}") - +class ModelScan: + def __init__( + self, + scanners_to_load: List[str] = DEFAULT_SCANNERS, + settings: Dict[str, Any] = DEFAULT_SETTINGS, + ) -> None: # Output self._issues = Issues() self._errors: List[Error] = [] self._skipped: List[str] = [] self._scanned: List[str] = [] + self._input_path: str = "" + + # Scanners + self._scanners_to_run: List[ScanBase] = [] + self._settings: Dict[str, Any] = settings + self._load_scanners(scanners_to_load) + + def _load_scanners(self, scanners_to_load: List[str]) -> None: + scanner_classes: Dict[str, ScanBase] = {} + for scanner_path in scanners_to_load: + try: + (modulename, classname) = scanner_path.rsplit(".", 1) + imported_module = importlib.import_module( + name=modulename, package=classname + ) + scanner_class = getattr(imported_module, classname) + scanner_classes[scanner_path] = scanner_class + except Exception as e: + logger.error(f"Error importing scanner {scanner_path}") + self._errors.append( + ModelScanError( + scanner_path, f"Error importing scanner {scanner_path}: {e}" + ) + ) + + scanners_to_run: List[ScanBase] = [] + for scanner_class, scanner in scanner_classes.items(): + is_enabled: bool = self._settings["scanners"][scanner_class]["enabled"] + if is_enabled: + dep_error = scanner.handle_binary_dependencies() + if dep_error: + logger.info( + f"Skipping {scanner.full_name()} as it is missing dependencies" + ) + self._errors.append(dep_error) + else: + scanners_to_run.append(scanner) + self._scanners_to_run = scanners_to_run + + def scan( + self, + path: Union[str, Path], + ) -> Dict[str, Any]: + self._issues = Issues() + self._errors = [] + self._skipped = [] + self._scanned = [] + self._input_path = str(path) - def scan_path(self, path: Path) -> None: - if path.is_dir(): - self._scan_directory(path) - elif _is_zipfile(path) or path.suffix in self.supported_zip_extensions: - is_keras_file = path.suffix in KerasScan.supported_extensions() - if is_keras_file: - self._scan_source(source=path, extension=path.suffix) - else: + self._scan_path(Path(path)) + return self._generate_results() + + def _scan_path( + self, + path: Path, + ) -> None: + if Path.exists(path): + scanned = self._scan_source(path) + if not scanned and path.is_dir(): + self._scan_directory(path) + elif ( + _is_zipfile(path) + or path.suffix in self._settings["supported_zip_extensions"] + ): self._scan_zip(path) + elif not scanned: + self._skipped.append(str(path)) else: - self._scan_source(source=path, extension=path.suffix) + logger.error(f"Error: path {path} is not valid") + self._errors.append( + ModelScanError("ModelScan", f"Path {path} is not valid") + ) + self._skipped.append(str(path)) def _scan_directory(self, directory_path: Path) -> None: for path in directory_path.rglob("*"): if not path.is_dir(): - self.scan_path(path) + self._scan_path(path) def _scan_source( self, source: Union[str, Path], - extension: str, data: Optional[IO[bytes]] = None, - ) -> None: - issues: List[Issue] = [] - errors: List[Error] = [] - - if extension not in self.supported_extensions: - logger.debug(f"Skipping file {source}") - self._skipped.append(str(source)) - return - - for scan in self.supported_model_scans: - if extension in scan.supported_extensions(): - logger.info(f"Scanning {source} using {scan.name()} model scan") - issues, errors = scan.scan(source=source, data=data) + ) -> bool: + scanned = False + for scan_class in self._scanners_to_run: + scanner = scan_class(self._settings["scanners"]) # type: ignore[operator] + scan_results = scanner.scan( + source=source, + data=data, + ) + if scan_results is not None: + logger.info(f"Scanning {source} using {scanner.full_name()} model scan") self._scanned.append(str(source)) - - self._issues.add_issues(issues) - self._errors.extend(errors) + self._issues.add_issues(scan_results.issues) + self._errors.extend(scan_results.errors) + scanned = True + return scanned def _scan_zip( self, source: Union[str, Path], data: Optional[IO[bytes]] = None @@ -91,17 +134,49 @@ def _scan_zip( with zipfile.ZipFile(data or source, "r") as zip: file_names = zip.namelist() for file_name in file_names: - file_ext = os.path.splitext(file_name)[1] with zip.open(file_name, "r") as file_io: self._scan_source( source=f"{source}:{file_name}", - extension=file_ext, data=file_io, ) except zipfile.BadZipFile as e: logger.debug(f"Skipping zip file {source}, due to error", e, exc_info=True) self._skipped.append(str(source)) + def _generate_results(self) -> Dict[str, Any]: + report: Dict[str, Any] = {} + + issues_by_severity = self._issues.group_by_severity() + total_issue_count = len(self._issues.all_issues) + + report["modelscan_version"] = __version__ + report["timestamp"] = datetime.now().isoformat() + report["input_path"] = self._input_path + report["total_issues"] = total_issue_count + report["summary"] = {"total_issues_by_severity": {}} + for severity in IssueSeverity: + if severity.name in issues_by_severity: + report["summary"]["total_issues_by_severity"][severity.name] = len( + issues_by_severity[severity.name] + ) + else: + report["summary"]["total_issues_by_severity"][severity.name] = 0 + + report["issues_by_severity"] = {} + for issue_key in issues_by_severity.keys(): + report["issues_by_severity"][issue_key] = [ + issue.details.output_json() for issue in issues_by_severity[issue_key] + ] + + report["errors"] = [str(error) for index, error in enumerate(self._errors)] + + report["skipped"] = {"total_skipped": len(self._skipped)} + report["skipped"]["skipped_files"] = [ + str(file_name) for file_name in self._skipped + ] + + return report + @property def issues(self) -> Issues: return self._issues diff --git a/modelscan/scanners/__init__.py b/modelscan/scanners/__init__.py new file mode 100644 index 0000000..2a45017 --- /dev/null +++ b/modelscan/scanners/__init__.py @@ -0,0 +1,8 @@ +from modelscan.scanners.h5.scan import H5Scan +from modelscan.scanners.pickle.scan import ( + PickleScan, + NumpyScan, + PyTorchScan, +) +from modelscan.scanners.saved_model.scan import SavedModelScan +from modelscan.scanners.keras.scan import KerasScan diff --git a/modelscan/models/h5/__init__.py b/modelscan/scanners/h5/__init__.py similarity index 100% rename from modelscan/models/h5/__init__.py rename to modelscan/scanners/h5/__init__.py diff --git a/modelscan/models/h5/scan.py b/modelscan/scanners/h5/scan.py similarity index 57% rename from modelscan/models/h5/scan.py rename to modelscan/scanners/h5/scan.py index 02810bd..ceac067 100644 --- a/modelscan/models/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -1,7 +1,7 @@ import json import logging from pathlib import Path -from typing import IO, List, Tuple, Union, Optional +from typing import IO, List, Union, Optional, Dict, Any try: import h5py @@ -10,49 +10,46 @@ except ImportError: h5py_installed = False -from modelscan.error import Error, ModelScanError -from modelscan.issues import Issue -from modelscan.models.saved_model.scan import SavedModelScan +from modelscan.error import ModelScanError +from modelscan.scanners.scan import ScanResults +from modelscan.scanners.saved_model.scan import SavedModelScan logger = logging.getLogger("modelscan") class H5Scan(SavedModelScan): - @staticmethod def scan( + self, source: Union[str, Path], data: Optional[IO[bytes]] = None, - ) -> Tuple[List[Issue], List[Error]]: - if not h5py_installed: - return [], [ - ModelScanError( - SavedModelScan.name(), - f"File: {source} \nTo scan an h5py file, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.", - ) - ] + ) -> Optional[ScanResults]: + if ( + not Path(source).suffix + in self._settings[H5Scan.full_name()]["supported_extensions"] + ): + return None if data: logger.warning( "H5 scanner got data bytes. It only support direct file scanning." ) + return None - return H5Scan._scan_keras_h5_file(source) + return self.label_results(self._scan_keras_h5_file(source)) - @staticmethod - def _scan_keras_h5_file( - source: Union[str, Path] - ) -> Tuple[List[Issue], List[Error]]: + def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults: machine_learning_library_name = "Keras" - operators_in_model = H5Scan._get_keras_h5_operator_names(source) + operators_in_model = self._get_keras_h5_operator_names(source) return H5Scan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, source=source, + settings=self._settings, ) - @staticmethod - def _get_keras_h5_operator_names(source: Union[str, Path]) -> List[str]: + def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]: # Todo: source isn't guaranteed to be a file + with h5py.File(source, "r") as model_hdf5: try: model_config = json.loads(model_hdf5.attrs.get("model_config", {})) @@ -72,10 +69,21 @@ def _get_keras_h5_operator_names(source: Union[str, Path]) -> List[str]: return [] - @staticmethod - def supported_extensions() -> List[str]: - return [".h5"] - @staticmethod def name() -> str: return "hdf5" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.H5Scan" + + @staticmethod + def handle_binary_dependencies( + settings: Optional[Dict[str, Any]] = None + ) -> Optional[ModelScanError]: + if not h5py_installed: + return ModelScanError( + SavedModelScan.name(), + f"To use {H5Scan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.", + ) + return None diff --git a/modelscan/models/keras/__init__.py b/modelscan/scanners/keras/__init__.py similarity index 100% rename from modelscan/models/keras/__init__.py rename to modelscan/scanners/keras/__init__.py diff --git a/modelscan/models/keras/scan.py b/modelscan/scanners/keras/scan.py similarity index 53% rename from modelscan/models/keras/scan.py rename to modelscan/scanners/keras/scan.py index 73929c5..8c13002 100644 --- a/modelscan/models/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -2,61 +2,77 @@ import zipfile import logging from pathlib import Path -from typing import IO, List, Tuple, Union, Optional +from typing import IO, List, Union, Optional + + +from modelscan.error import ModelScanError +from modelscan.scanners.scan import ScanResults +from modelscan.scanners.saved_model.scan import SavedModelScan -from modelscan.error import Error, ModelScanError -from modelscan.issues import Issue -from modelscan.models.saved_model.scan import SavedModelScan logger = logging.getLogger("modelscan") class KerasScan(SavedModelScan): - @staticmethod def scan( + self, source: Union[str, Path], data: Optional[IO[bytes]] = None, - ) -> Tuple[List[Issue], List[Error]]: + ) -> Optional[ScanResults]: + if ( + not Path(source).suffix + in self._settings[KerasScan.full_name()]["supported_extensions"] + ): + return None + try: with zipfile.ZipFile(data or source, "r") as zip: file_names = zip.namelist() for file_name in file_names: if file_name == "config.json": with zip.open(file_name, "r") as config_file: - return KerasScan._scan_keras_config_file( - source=f"{source}:{file_name}", config_file=config_file + return self.label_results( + self._scan_keras_config_file( + source=f"{source}:{file_name}", + config_file=config_file, + ) ) except zipfile.BadZipFile as e: - return [], [ + return ScanResults( + [], + [ + ModelScanError( + KerasScan.name(), + f"Skipping zip file {source}, due to error: {e}", + ) + ], + ) + + # Added return to pass the failing mypy test: Missing return statement + return ScanResults( + [], + [ ModelScanError( KerasScan.name(), - f"Skipping zip file {source}, due to error: {e}", + f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError ) - ] - - # Added return to pass the failing mypy test: Missing return statement - return [], [ - ModelScanError( - KerasScan.name(), - f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError - ) - ] + ], + ) - @staticmethod def _scan_keras_config_file( - source: Union[str, Path], config_file: IO[bytes] - ) -> Tuple[List[Issue], List[Error]]: + self, source: Union[str, Path], config_file: IO[bytes] + ) -> ScanResults: machine_learning_library_name = "Keras" - operators_in_model = KerasScan._get_keras_operator_names(source, config_file) + operators_in_model = self._get_keras_operator_names(source, config_file) return KerasScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, source=source, + settings=self._settings, ) - @staticmethod def _get_keras_operator_names( - source: Union[str, Path], data: IO[bytes] + self, source: Union[str, Path], data: IO[bytes] ) -> List[str]: try: model_config_data = json.load(data) @@ -65,6 +81,7 @@ def _get_keras_operator_names( for layer in model_config_data.get("config", {}).get("layers", {}) if layer.get("class_name", {}) == "Lambda" ] + if lambda_layers: return ["Lambda"] * len(lambda_layers) @@ -75,9 +92,9 @@ def _get_keras_operator_names( return [] @staticmethod - def supported_extensions() -> List[str]: - return [".keras"] + def name() -> str: + return "keras" @staticmethod - def name() -> str: - return ".keras" + def full_name() -> str: + return "modelscan.scanners.KerasScan" diff --git a/modelscan/models/pickle/__init__.py b/modelscan/scanners/pickle/__init__.py similarity index 100% rename from modelscan/models/pickle/__init__.py rename to modelscan/scanners/pickle/__init__.py diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py new file mode 100644 index 0000000..7ac1efd --- /dev/null +++ b/modelscan/scanners/pickle/scan.py @@ -0,0 +1,107 @@ +import logging +from pathlib import Path +from typing import IO, Union, Optional + +from modelscan.scanners.scan import ScanBase, ScanResults +from modelscan.tools.picklescanner import ( + scan_numpy, + scan_pickle_bytes, + scan_pytorch, +) + +logger = logging.getLogger("modelscan") + + +class PyTorchScan(ScanBase): + def scan( + self, + source: Union[str, Path], + data: Optional[IO[bytes]] = None, + ) -> Optional[ScanResults]: + if ( + not Path(source).suffix + in self._settings[PyTorchScan.full_name()]["supported_extensions"] + ): + return None + + if data: + results = scan_pytorch(data=data, source=source, settings=self._settings) + + else: + with open(source, "rb") as file_io: + results = scan_pytorch( + data=file_io, source=source, settings=self._settings + ) + + return self.label_results(results) + + @staticmethod + def name() -> str: + return "pytorch" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.PyTorchScan" + + +class NumpyScan(ScanBase): + def scan( + self, + source: Union[str, Path], + data: Optional[IO[bytes]] = None, + ) -> Optional[ScanResults]: + if ( + not Path(source).suffix + in self._settings[NumpyScan.full_name()]["supported_extensions"] + ): + return None + + if data: + results = scan_numpy(data=data, source=source, settings=self._settings) + + with open(source, "rb") as file_io: + results = scan_numpy(data=file_io, source=source, settings=self._settings) + + return self.label_results(results) + + @staticmethod + def name() -> str: + return "numpy" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.NumpyScan" + + +class PickleScan(ScanBase): + def scan( + self, + source: Union[str, Path], + data: Optional[IO[bytes]] = None, + ) -> Optional[ScanResults]: + if ( + not Path(source).suffix + in self._settings[PickleScan.full_name()]["supported_extensions"] + ): + return None + + if data: + results = scan_pickle_bytes( + data=data, source=source, settings=self._settings + ) + + else: + with open(source, "rb") as file_io: + results = scan_pickle_bytes( + data=file_io, source=source, settings=self._settings + ) + + return self.label_results(results) + + @staticmethod + def name() -> str: + return "pickle" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.PickleScan" diff --git a/modelscan/models/saved_model/__init__.py b/modelscan/scanners/saved_model/__init__.py similarity index 100% rename from modelscan/models/saved_model/__init__.py rename to modelscan/scanners/saved_model/__init__.py diff --git a/modelscan/models/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py similarity index 68% rename from modelscan/models/saved_model/scan.py rename to modelscan/scanners/saved_model/scan.py index 306a7f5..0894b38 100644 --- a/modelscan/models/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -4,8 +4,7 @@ import logging from pathlib import Path -from typing import IO, List, Set, Tuple, Union, Optional, Dict - +from typing import IO, List, Set, Union, Optional, Dict, Any try: import tensorflow @@ -17,53 +16,49 @@ tensorflow_installed = False -from modelscan.error import Error, ModelScanError +from modelscan.error import ModelScanError from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails -from modelscan.models.scan import ScanBase +from modelscan.scanners.scan import ScanBase, ScanResults logger = logging.getLogger("modelscan") class SavedModelScan(ScanBase): - @staticmethod def scan( + self, source: Union[str, Path], data: Optional[IO[bytes]] = None, - ) -> Tuple[List[Issue], List[Error]]: - if not tensorflow_installed: - return [], [ - ModelScanError( - SavedModelScan.name(), - f"File: {source} \nTo scan an tensorflow file, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.", - ) - ] + ) -> Optional[ScanResults]: + if ( + not Path(source).suffix + in self._settings[SavedModelScan.full_name()]["supported_extensions"] + ): + return None if data: - return SavedModelScan._scan(source, data) + results = self._scan(source, data) - with open(source, "rb") as file_io: - return SavedModelScan._scan(source, data=file_io) + else: + with open(source, "rb") as file_io: + results = self._scan(source, data=file_io) - @staticmethod - def _scan( - source: Union[str, Path], data: IO[bytes] - ) -> Tuple[List[Issue], List[Error]]: + return self.label_results(results) + + def _scan(self, source: Union[str, Path], data: IO[bytes]) -> ScanResults: file_name = str(source).split("/")[-1] # Default is a tensorflow model file if file_name == "keras_metadata.pb": machine_learning_library_name = "Keras" - operators_in_model = SavedModelScan._get_keras_pb_operator_names( - data, source + operators_in_model = self._get_keras_pb_operator_names( + data=data, source=source ) else: machine_learning_library_name = "Tensorflow" - operators_in_model = SavedModelScan._get_tensorflow_operator_names( - data=data - ) + operators_in_model = self._get_tensorflow_operator_names(data=data) return SavedModelScan._check_for_unsafe_tf_keras_operator( - machine_learning_library_name, operators_in_model, source + machine_learning_library_name, operators_in_model, source, self._settings ) @staticmethod @@ -92,8 +87,7 @@ def _get_keras_pb_operator_names( return [] - @staticmethod - def _get_tensorflow_operator_names(data: IO[bytes]) -> List[str]: + def _get_tensorflow_operator_names(self, data: IO[bytes]) -> List[str]: saved_model = SavedModel() saved_model.ParseFromString(data.read()) @@ -112,13 +106,15 @@ def _get_tensorflow_operator_names(data: IO[bytes]) -> List[str]: # This function checks for malicious operators in both Keras and Tensorflow @staticmethod def _check_for_unsafe_tf_keras_operator( - module_name: str, raw_operator: List[str], source: Union[str, Path] - ) -> Tuple[List[Issue], List[Error]]: - unsafe_operators: Dict[str, IssueSeverity] = { - "ReadFile": IssueSeverity.HIGH, - "WriteFile": IssueSeverity.HIGH, - "Lambda": IssueSeverity.MEDIUM, - } + module_name: str, + raw_operator: List[str], + source: Union[str, Path], + settings: Dict[str, Any], + ) -> ScanResults: + unsafe_operators: Dict[str, IssueSeverity] = settings[ + SavedModelScan.full_name() + ]["unsafe_tf_keras_operators"] + issues: List[Issue] = [] all_operators = tensorflow.raw_ops.__dict__.keys() all_safe_operators = [ @@ -142,12 +138,23 @@ def _check_for_unsafe_tf_keras_operator( ), ) ) - return issues, [] - - @staticmethod - def supported_extensions() -> List[str]: - return [".pb"] + return ScanResults(issues, []) @staticmethod def name() -> str: return "saved_model" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.SavedModelScan" + + @staticmethod + def handle_binary_dependencies( + settings: Optional[Dict[str, Any]] = None + ) -> Optional[ModelScanError]: + if not tensorflow_installed: + return ModelScanError( + SavedModelScan.name(), + f"To use {SavedModelScan.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.", + ) + return None diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py new file mode 100644 index 0000000..47aaf40 --- /dev/null +++ b/modelscan/scanners/scan.py @@ -0,0 +1,59 @@ +import abc +from pathlib import Path +from typing import List, Union, Optional, IO, Any, Dict + +from modelscan.error import Error, ModelScanError +from modelscan.issues import Issue + + +class ScanResults: + issues: List[Issue] + errors: List[Error] + + def __init__(self, issues: List[Issue], errors: List[Error]) -> None: + self.issues = issues + self.errors = errors + + +class ScanBase(metaclass=abc.ABCMeta): + def __init__( + self, + settings: Dict[str, Any], + ) -> None: + self._settings: Dict[str, Any] = settings + + @staticmethod + @abc.abstractmethod + def name() -> str: + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def full_name() -> str: + raise NotImplementedError + + @abc.abstractmethod + def scan( + self, + source: Union[str, Path], + data: Optional[IO[bytes]] = None, + ) -> Optional[ScanResults]: + raise NotImplementedError + + @staticmethod + def handle_binary_dependencies( + settings: Optional[Dict[str, Any]] = None + ) -> Optional[ModelScanError]: + """ + Implement this method if the plugin requires a binary dependency. + It should perform the following actions: + + 1. Check if the dependency is installed + 2. Return a ModelScanError prompting the install if not + """ + return None + + def label_results(self, results: ScanResults) -> ScanResults: + for issue in results.issues: + issue.details.scanner = self.full_name() + return results diff --git a/modelscan/settings.py b/modelscan/settings.py new file mode 100644 index 0000000..8b959fd --- /dev/null +++ b/modelscan/settings.py @@ -0,0 +1,89 @@ +from modelscan.issues import IssueSeverity + +DEFAULT_SCANNERS = [ + "modelscan.scanners.H5Scan", + "modelscan.scanners.KerasScan", + "modelscan.scanners.SavedModelScan", + "modelscan.scanners.NumpyScan", + "modelscan.scanners.PickleScan", + "modelscan.scanners.PyTorchScan", +] + +DEFAULT_SETTINGS = { + "supported_zip_extensions": [".zip", ".npz"], + "scanners": { + "modelscan.scanners.H5Scan": { + "enabled": True, + "supported_extensions": [".h5"], + }, + "modelscan.scanners.KerasScan": { + "enabled": True, + "supported_extensions": [".keras"], + }, + "modelscan.scanners.SavedModelScan": { + "enabled": True, + "supported_extensions": [".pb"], + "unsafe_tf_keras_operators": { + "ReadFile": IssueSeverity.HIGH, + "WriteFile": IssueSeverity.HIGH, + "Lambda": IssueSeverity.MEDIUM, + }, + }, + "modelscan.scanners.NumpyScan": { + "enabled": True, + "supported_extensions": [".npy"], + }, + "modelscan.scanners.PickleScan": { + "enabled": True, + "supported_extensions": [ + ".pkl", + ".pickle", + ".joblib", + ".dill", + ".dat", + ".data", + ], + }, + "modelscan.scanners.PyTorchScan": { + "enabled": True, + "supported_extensions": [".bin", ".pt", ".pth", ".ckpt"], + }, + "unsafe_globals": { + "CRITICAL": { + "__builtin__": { + "eval", + "compile", + "getattr", + "apply", + "exec", + "open", + "breakpoint", + }, # Pickle versions 0, 1, 2 have those function under '__builtin__' + "builtins": { + "eval", + "compile", + "getattr", + "apply", + "exec", + "open", + "breakpoint", + }, # Pickle versions 3, 4 have those function under 'builtins' + "runpy": "*", + "os": "*", + "nt": "*", # Alias for 'os' on Windows. Includes os.system() + "posix": "*", # Alias for 'os' on Linux. Includes os.system() + "socket": "*", + "subprocess": "*", + "sys": "*", + }, + "HIGH": { + "webbrowser": "*", # Includes webbrowser.open() + "httplib": "*", # Includes http.client.HTTPSConnection() + "requests.api": "*", + "aiohttp.client": "*", + }, + "MEDIUM": {}, + "LOW": {}, + }, + }, +} diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index 8a37296..2a96c31 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -9,6 +9,7 @@ from modelscan.error import Error, ModelScanError from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails +from modelscan.scanners.scan import ScanResults logger = logging.getLogger("modelscan") @@ -24,68 +25,6 @@ def __str__(self) -> str: return self.msg -_safe_globals: Dict[str, Set[str]] = { - "collections": {"OrderedDict"}, - "torch": { - "LongStorage", - "FloatStorage", - "HalfStorage", - "QUInt2x4Storage", - "QUInt4x2Storage", - "QInt32Storage", - "QInt8Storage", - "QUInt8Storage", - "ComplexFloatStorage", - "ComplexDoubleStorage", - "DoubleStorage", - "BFloat16Storage", - "BoolStorage", - "CharStorage", - "ShortStorage", - "IntStorage", - "ByteStorage", - }, - "torch._utils": {"_rebuild_tensor_v2"}, -} - -_unsafe_globals: Dict[str, Any] = { - "CRITICAL": { - "__builtin__": { - "eval", - "compile", - "getattr", - "apply", - "exec", - "open", - "breakpoint", - }, # Pickle versions 0, 1, 2 have those function under '__builtin__' - "builtins": { - "eval", - "compile", - "getattr", - "apply", - "exec", - "open", - "breakpoint", - }, # Pickle versions 3, 4 have those function under 'builtins' - "runpy": "*", - "os": "*", - "nt": "*", # Alias for 'os' on Windows. Includes os.system() - "posix": "*", # Alias for 'os' on Linux. Includes os.system() - "socket": "*", - "subprocess": "*", - "sys": "*", - }, - "HIGH": { - "webbrowser": "*", # Includes webbrowser.open() - "httplib": "*", # Includes http.client.HTTPSConnection() - "requests.api": "*", - "aiohttp.client": "*", - }, - "MEDIUM": {}, - "LOW": {}, -} - # # TODO: handle methods loading other Pickle files (either mark as suspicious, or follow calls to scan other files [preventing infinite loops]) # @@ -178,27 +117,31 @@ def _list_globals( def scan_pickle_bytes( data: IO[bytes], source: Union[Path, str], + settings: Dict[str, Any], scan_name: str = "pickle", multiple_pickles: bool = True, -) -> Tuple[List[Issue], List[Error]]: +) -> ScanResults: """Disassemble a Pickle stream and report issues""" issues: List[Issue] = [] try: raw_globals = _list_globals(data, multiple_pickles) except GenOpsError as e: - return issues, [ - ModelScanError(scan_name, f"Error parsing pickle file {source}: {e}") - ] + return ScanResults( + issues, + [ModelScanError(scan_name, f"Error parsing pickle file {source}: {e}")], + ) logger.debug("Global imports in %s: %s", source, raw_globals) for rg in raw_globals: global_module, global_name, severity = rg[0], rg[1], None - unsafe_critical_filter = _unsafe_globals["CRITICAL"].get(global_module) - unsafe_high_filter = _unsafe_globals["HIGH"].get(global_module) - unsafe_medium_filter = _unsafe_globals["MEDIUM"].get(global_module) - unsafe_low_filter = _unsafe_globals["LOW"].get(global_module) + unsafe_critical_filter = settings["unsafe_globals"]["CRITICAL"].get( + global_module + ) + unsafe_high_filter = settings["unsafe_globals"]["HIGH"].get(global_module) + unsafe_medium_filter = settings["unsafe_globals"]["MEDIUM"].get(global_module) + unsafe_low_filter = settings["unsafe_globals"]["LOW"].get(global_module) if unsafe_critical_filter is not None and ( unsafe_critical_filter == "*" or global_name in unsafe_critical_filter ): @@ -229,12 +172,12 @@ def scan_pickle_bytes( ), ) ) - return issues, [] + return ScanResults(issues, []) def scan_numpy( - data: IO[bytes], source: Union[str, Path] -) -> Tuple[List[Issue], List[Error]]: + data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any] +) -> ScanResults: # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = b"PK\x03\x04" _ZIP_SUFFIX = b"PK\x05\x06" # empty zip files start with this @@ -253,16 +196,16 @@ def scan_numpy( _, _, dtype = np.lib.format._read_array_header(data, version) # type: ignore[attr-defined] if dtype.hasobject: - return scan_pickle_bytes(data, source, "numpy") + return scan_pickle_bytes(data, source, settings, "numpy") else: - return [], [] + return ScanResults([], []) else: - return scan_pickle_bytes(data, source, "numpy") + return scan_pickle_bytes(data, source, settings, "numpy") def scan_pytorch( - data: IO[bytes], source: Union[str, Path] -) -> Tuple[List[Issue], List[Error]]: + data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any] +) -> ScanResults: should_read_directly = _should_read_directly(data) if should_read_directly and data.tell() == 0: # try loading from tar @@ -275,7 +218,7 @@ def scan_pytorch( magic = get_magic_number(data) if magic != MAGIC_NUMBER: - return [], [ - ModelScanError("pytorch", f"Invalid magic number for file {source}") - ] - return scan_pickle_bytes(data, source, "pytorch", multiple_pickles=False) + return ScanResults( + [], [ModelScanError("pytorch", f"Invalid magic number for file {source}")] + ) + return scan_pickle_bytes(data, source, settings, "pytorch", multiple_pickles=False) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index a82c105..805c2e7 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -14,18 +14,26 @@ import sys import tensorflow as tf from tensorflow import keras -from typing import Any, List, Set +from typing import Any, List, Set, Dict from test_utils import generate_dill_unsafe_file, generate_unsafe_pickle_file import zipfile -from modelscan.modelscan import Modelscan +from modelscan.modelscan import ModelScan from modelscan.cli import cli -from modelscan.error import ModelScanError -from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails +from modelscan.issues import ( + Issue, + IssueCode, + IssueSeverity, + OperatorIssueDetails, +) from modelscan.tools.picklescanner import ( scan_pickle_bytes, scan_numpy, ) +from modelscan.scanners.saved_model.scan import SavedModelScan +from modelscan.settings import DEFAULT_SETTINGS + +settings: Dict[str, Any] = DEFAULT_SETTINGS["scanners"] # type: ignore[assignment] class Malicious1: @@ -271,7 +279,9 @@ def test_scan_pickle_bytes() -> None: ) ] assert ( - scan_pickle_bytes(io.BytesIO(pickle.dumps(Malicious1())), "file.pkl")[0] + scan_pickle_bytes( + io.BytesIO(pickle.dumps(Malicious1())), "file.pkl", settings + ).issues == expected ) @@ -287,21 +297,21 @@ def test_scan_zip(zip_file_path: Any) -> None: ) ] - ms = Modelscan() + ms = ModelScan() ms._scan_zip(f"{zip_file_path}/test.zip") assert ms.issues.all_issues == expected def test_scan_pytorch(pytorch_file_path: Any) -> None: - bad_pytorch = Modelscan() - bad_pytorch.scan_path(Path(f"{pytorch_file_path}/bad_pytorch.pt")) + bad_pytorch = ModelScan() + bad_pytorch.scan(Path(f"{pytorch_file_path}/bad_pytorch.pt")) assert bad_pytorch.issues.all_issues == [] assert [error.scan_name for error in bad_pytorch.errors] == ["pytorch"] # type: ignore[attr-defined] def test_scan_numpy(numpy_file_path: Any) -> None: with open(f"{numpy_file_path}/safe_numpy.npy", "rb") as f: - assert scan_numpy(io.BytesIO(f.read()), "safe_numpy.npy")[0] == [] + assert scan_numpy(io.BytesIO(f.read()), "safe_numpy.npy", settings).issues == [] expected = { Issue( @@ -313,19 +323,20 @@ def test_scan_numpy(numpy_file_path: Any) -> None: with open(f"{numpy_file_path}/unsafe_numpy.npy", "rb") as f: compare_results( - scan_numpy(io.BytesIO(f.read()), "unsafe_numpy.npy")[0], expected + scan_numpy(io.BytesIO(f.read()), "unsafe_numpy.npy", settings).issues, + expected, ) def test_scan_file_path(file_path: Any) -> None: - benign_pickle = Modelscan() - benign_pickle.scan_path(Path(f"{file_path}/data/benign0_v3.pkl")) - benign_dill = Modelscan() - benign_dill.scan_path(Path(f"{file_path}/data/benign0_v3.dill")) + benign_pickle = ModelScan() + benign_pickle.scan(Path(f"{file_path}/data/benign0_v3.pkl")) + benign_dill = ModelScan() + benign_dill.scan(Path(f"{file_path}/data/benign0_v3.dill")) assert benign_pickle.issues.all_issues == [] assert benign_dill.issues.all_issues == [] - malicious0 = Modelscan() + malicious0 = ModelScan() expected_malicious0 = { Issue( IssueCode.UNSAFE_OPERATOR, @@ -356,7 +367,7 @@ def test_scan_file_path(file_path: Any) -> None: ), ), } - malicious0.scan_path(Path(f"{file_path}/data/malicious0.pkl")) + malicious0.scan(Path(f"{file_path}/data/malicious0.pkl")) compare_results(malicious0.issues.all_issues, expected_malicious0) @@ -427,21 +438,21 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious1_v0 = Modelscan() - malicious1_v3 = Modelscan() - malicious1_v4 = Modelscan() - malicious1_v0_dill = Modelscan() - malicious1_v3_dill = Modelscan() - malicious1_v4_dill = Modelscan() - - malicious1 = Modelscan() - malicious1_v0.scan_path(Path(f"{file_path}/data/malicious1_v0.pkl")) - malicious1_v3.scan_path(Path(f"{file_path}/data/malicious1_v3.pkl")) - malicious1_v4.scan_path(Path(f"{file_path}/data/malicious1_v4.pkl")) - malicious1_v0_dill.scan_path(Path(f"{file_path}/data/malicious1_v0.dill")) - malicious1_v3_dill.scan_path(Path(f"{file_path}/data/malicious1_v3.dill")) - malicious1_v4_dill.scan_path(Path(f"{file_path}/data/malicious1_v4.dill")) - malicious1.scan_path(Path(f"{file_path}/data/malicious1.zip")) + malicious1_v0 = ModelScan() + malicious1_v3 = ModelScan() + malicious1_v4 = ModelScan() + malicious1_v0_dill = ModelScan() + malicious1_v3_dill = ModelScan() + malicious1_v4_dill = ModelScan() + + malicious1 = ModelScan() + malicious1_v0.scan(Path(f"{file_path}/data/malicious1_v0.pkl")) + malicious1_v3.scan(Path(f"{file_path}/data/malicious1_v3.pkl")) + malicious1_v4.scan(Path(f"{file_path}/data/malicious1_v4.pkl")) + malicious1_v0_dill.scan(Path(f"{file_path}/data/malicious1_v0.dill")) + malicious1_v3_dill.scan(Path(f"{file_path}/data/malicious1_v3.dill")) + malicious1_v4_dill.scan(Path(f"{file_path}/data/malicious1_v4.dill")) + malicious1.scan(Path(f"{file_path}/data/malicious1.zip")) assert malicious1_v0.issues.all_issues == expected_malicious1_v0 assert malicious1_v3.issues.all_issues == expected_malicious1_v3 assert malicious1_v4.issues.all_issues == expected_malicious1_v4 @@ -477,12 +488,12 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious2_v0 = Modelscan() - malicious2_v3 = Modelscan() - malicious2_v4 = Modelscan() - malicious2_v0.scan_path(Path(f"{file_path}/data/malicious2_v0.pkl")) - malicious2_v3.scan_path(Path(f"{file_path}/data/malicious2_v3.pkl")) - malicious2_v4.scan_path(Path(f"{file_path}/data/malicious2_v4.pkl")) + malicious2_v0 = ModelScan() + malicious2_v3 = ModelScan() + malicious2_v4 = ModelScan() + malicious2_v0.scan(Path(f"{file_path}/data/malicious2_v0.pkl")) + malicious2_v3.scan(Path(f"{file_path}/data/malicious2_v3.pkl")) + malicious2_v4.scan(Path(f"{file_path}/data/malicious2_v4.pkl")) assert malicious2_v0.issues.all_issues == expected_malicious2_v0 assert malicious2_v3.issues.all_issues == expected_malicious2_v3 assert malicious2_v4.issues.all_issues == expected_malicious2_v4 @@ -498,8 +509,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious3 = Modelscan() - malicious3.scan_path(Path(f"{file_path}/data/malicious3.pkl")) + malicious3 = ModelScan() + malicious3.scan(Path(f"{file_path}/data/malicious3.pkl")) assert malicious3.issues.all_issues == expected_malicious3 expected_malicious4 = [ @@ -511,8 +522,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious4 = Modelscan() - malicious4.scan_path(Path(f"{file_path}/data/malicious4.pickle")) + malicious4 = ModelScan() + malicious4.scan(Path(f"{file_path}/data/malicious4.pickle")) assert malicious4.issues.all_issues == expected_malicious4 expected_malicious5 = [ @@ -526,8 +537,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious5 = Modelscan() - malicious5.scan_path(Path(f"{file_path}/data/malicious5.pickle")) + malicious5 = ModelScan() + malicious5.scan(Path(f"{file_path}/data/malicious5.pickle")) assert malicious5.issues.all_issues == expected_malicious5 expected_malicious6 = [ @@ -539,8 +550,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious6 = Modelscan() - malicious6.scan_path(Path(f"{file_path}/data/malicious6.pkl")) + malicious6 = ModelScan() + malicious6.scan(Path(f"{file_path}/data/malicious6.pkl")) assert malicious6.issues.all_issues == expected_malicious6 expected_malicious7 = [ @@ -552,8 +563,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious7 = Modelscan() - malicious7.scan_path(Path(f"{file_path}/data/malicious7.pkl")) + malicious7 = ModelScan() + malicious7.scan(Path(f"{file_path}/data/malicious7.pkl")) assert malicious7.issues.all_issues == expected_malicious7 expected_malicious8 = [ @@ -565,8 +576,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious8 = Modelscan() - malicious8.scan_path(Path(f"{file_path}/data/malicious8.pkl")) + malicious8 = ModelScan() + malicious8.scan(Path(f"{file_path}/data/malicious8.pkl")) assert malicious8.issues.all_issues == expected_malicious8 expected_malicious9 = [ @@ -576,8 +587,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: OperatorIssueDetails("sys", "exit", f"{file_path}/data/malicious9.pkl"), ) ] - malicious9 = Modelscan() - malicious9.scan_path(Path(f"{file_path}/data/malicious9.pkl")) + malicious9 = ModelScan() + malicious9.scan(Path(f"{file_path}/data/malicious9.pkl")) assert malicious9.issues.all_issues == expected_malicious9 expected_malicious10 = [ @@ -589,8 +600,8 @@ def test_scan_pickle_operators(file_path: Any) -> None: ), ) ] - malicious10 = Modelscan() - malicious10.scan_path(Path(f"{file_path}/data/malicious10.pkl")) + malicious10 = ModelScan() + malicious10.scan(Path(f"{file_path}/data/malicious10.pkl")) assert malicious10.issues.all_issues == expected_malicious10 @@ -752,16 +763,16 @@ def test_scan_directory_path(file_path: str) -> None: ), ), } - ms = Modelscan() + ms = ModelScan() p = Path(f"{file_path}/data/") - ms.scan_path(p) + ms.scan(p) compare_results(ms.issues.all_issues, expected) @pytest.mark.parametrize("file_extension", [".h5", ".keras"], ids=["h5", "keras"]) def test_scan_keras(keras_file_path: Any, file_extension: str) -> None: - ms = Modelscan() - ms.scan_path(Path(f"{keras_file_path}/safe{file_extension}")) + ms = ModelScan() + ms.scan(Path(f"{keras_file_path}/safe{file_extension}")) assert ms.issues.all_issues == [] if file_extension == ".keras": @@ -787,7 +798,6 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None: ] ms._scan_source( Path(f"{keras_file_path}/unsafe{file_extension}"), - extension=file_extension, ) else: expected = [ @@ -810,7 +820,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None: ), ), ] - ms.scan_path(Path(f"{keras_file_path}/unsafe{file_extension}")) + ms._scan_path(Path(f"{keras_file_path}/unsafe{file_extension}")) assert ms.issues.all_issues == expected