From 705d17fce14ec20d6ee7eca92d01efcb022f40d6 Mon Sep 17 00:00:00 2001 From: Oleksandr Yaremchuk Date: Thu, 14 Mar 2024 16:35:51 +0100 Subject: [PATCH] * use model in each scanner --- modelscan/model.py | 39 ++++++--- modelscan/modelscan.py | 32 ++++---- modelscan/scanners/h5/scan.py | 47 ++++++----- modelscan/scanners/keras/scan.py | 48 ++++++----- modelscan/scanners/pickle/scan.py | 69 +++++++++------- modelscan/scanners/saved_model/scan.py | 109 ++++++++++++------------- modelscan/scanners/scan.py | 7 +- modelscan/tools/picklescanner.py | 50 +++++------- tests/test_modelscan.py | 10 +-- 9 files changed, 210 insertions(+), 201 deletions(-) diff --git a/modelscan/model.py b/modelscan/model.py index 80ee9f0..a768882 100644 --- a/modelscan/model.py +++ b/modelscan/model.py @@ -2,13 +2,16 @@ from typing import List, Union, Optional, IO, Generator from modelscan.tools.utils import _is_zipfile import zipfile -from dataclasses import dataclass class ModelPathNotValid(ValueError): pass +class ModelDataEmpty(ValueError): + pass + + class ModelBadZip(ValueError): def __init__(self, e: zipfile.BadZipFile, source: str): self.source = source @@ -16,12 +19,12 @@ def __init__(self, e: zipfile.BadZipFile, source: str): class Model: - source: Path - data: Optional[IO[bytes]] = None + _source: Path + _data: Optional[IO[bytes]] = None def __init__(self, source: Union[str, Path], data: Optional[IO[bytes]] = None): - self.source = Path(source) - self.data = data + self._source = Path(source) + self._data = data @staticmethod def from_path(path: Path) -> "Model": @@ -31,8 +34,8 @@ def from_path(path: Path) -> "Model": return Model(path) def get_files(self) -> Generator["Model", None, None]: - if Path.is_dir(self.source): - for f in Path(self.source).rglob("*"): + if Path.is_dir(self._source): + for f in Path(self._source).rglob("*"): if Path.is_file(f): yield Model(f) @@ -40,16 +43,28 @@ def get_zip_files( self, supported_extensions: List[str] ) -> Generator["Model", None, None]: if ( - not _is_zipfile(self.source) - and Path(self.source).suffix not in supported_extensions + not _is_zipfile(self._source) + and Path(self._source).suffix not in supported_extensions ): return try: - with zipfile.ZipFile(self.source, "r") as zip: + with zipfile.ZipFile(self._source, "r") as zip: file_names = zip.namelist() for file_name in file_names: with zip.open(file_name, "r") as file_io: - yield Model(f"{self.source}:{file_name}", file_io) + yield Model(f"{self._source}:{file_name}", file_io) except zipfile.BadZipFile as e: - raise ModelBadZip(e, f"{self.source}:{file_name}") + raise ModelBadZip(e, f"{self._source}:{file_name}") + + def get_source(self) -> Path: + return self._source + + def has_data(self) -> bool: + return self._data is not None + + def get_data(self) -> IO[bytes]: + if not self._data: + raise ModelDataEmpty("Model data is empty.") + + return self._data diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index d37454c..dddbabf 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -108,7 +108,9 @@ def _scan_model( ) except ModelBadZip as e: logger.debug( - f"Skipping zip file {model.source}, due to error", e, exc_info=True + f"Skipping zip file {str(model.get_source())}, due to error", + e, + exc_info=True, ) self._skipped.append( ModelScanSkipped( @@ -125,38 +127,43 @@ def _scan_model( has_extracted = True scanned = self._scan_source(extracted_model) if not scanned: - if _is_zipfile(extracted_model.source, data=extracted_model.data): + if _is_zipfile( + extracted_model.get_source(), + 
data=extracted_model.get_data() + if extracted_model.has_data() + else None, + ): self._errors.append( ModelScanError( "ModelScan", ErrorCategories.NESTED_ZIP, "ModelScan does not support nested zip files.", - str(extracted_model.source), + str(extracted_model.get_source()), ) ) # check if added to skipped already all_skipped_files = [skipped.source for skipped in self._skipped] - if str(extracted_model.source) not in all_skipped_files: + if str(extracted_model.get_source()) not in all_skipped_files: self._skipped.append( ModelScanSkipped( "ModelScan", SkipCategories.SCAN_NOT_SUPPORTED, f"Model Scan did not scan file", - str(extracted_model.source), + str(extracted_model.get_source()), ) ) if not scanned and not has_extracted: # check if added to skipped already all_skipped_files = [skipped.source for skipped in self._skipped] - if str(model.source) not in all_skipped_files: + if str(model.get_source()) not in all_skipped_files: self._skipped.append( ModelScanSkipped( "ModelScan", SkipCategories.SCAN_NOT_SUPPORTED, f"Model Scan did not scan file", - str(model.source), + str(model.get_source()), ) ) @@ -167,26 +174,23 @@ def _scan_source( scanned = False for scan_class in self._scanners_to_run: scanner = scan_class(self._settings) # type: ignore[operator] - scan_results = scanner.scan( - source=model.source, - data=model.data, - ) + scan_results = scanner.scan(model) if scan_results is not None: scanned = True logger.info( - f"Scanning {model.source} using {scanner.full_name()} model scan" + f"Scanning {model.get_source()} using {scanner.full_name()} model scan" ) if scan_results.errors: self._errors.extend(scan_results.errors) elif scan_results.issues: - self._scanned.append(str(model.source)) + self._scanned.append(str(model.get_source())) self._issues.add_issues(scan_results.issues) elif scan_results.skipped: self._skipped.extend(scan_results.skipped) else: - self._scanned.append(str(model.source)) + self._scanned.append(str(model.get_source())) return scanned diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index ee350fe..a9fc566 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -1,7 +1,6 @@ import json import logging -from pathlib import Path -from typing import IO, List, Union, Optional, Dict, Any +from typing import List, Optional, Dict, Any try: @@ -15,6 +14,7 @@ from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan +from modelscan.model import Model logger = logging.getLogger("modelscan") @@ -22,11 +22,10 @@ class H5LambdaDetectScan(SavedModelLambdaDetectScan): def scan( self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, + model: Model, ) -> Optional[ScanResults]: if ( - not Path(source).suffix + not model.get_source().suffix in self._settings["scanners"][H5LambdaDetectScan.full_name()][ "supported_extensions" ] @@ -46,7 +45,7 @@ def scan( [], ) - if data: + if model.has_data(): logger.warning( f"{self.full_name()} got data bytes. It only support direct file scanning." ) @@ -58,21 +57,21 @@ def scan( self.name(), SkipCategories.H5_DATA, f"{self.full_name()} got data bytes. 
It only support direct file scanning.", - str(source), + str(model.get_source()), ) ], ) - results = self._scan_keras_h5_file(source) + results = self._scan_keras_h5_file(model) if results: return self.label_results(results) - else: - return None - def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]: + return None + + def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]: machine_learning_library_name = "Keras" - if self._check_model_config(source): - operators_in_model = self._get_keras_h5_operator_names(source) + if self._check_model_config(model): + operators_in_model = self._get_keras_h5_operator_names(model) if operators_in_model is None: return None @@ -84,7 +83,7 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults] self.name(), ErrorCategories.JSON_DECODE, f"Not a valid JSON data", - str(source), + str(model.get_source()), ) ], [], @@ -92,7 +91,7 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults] return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, - source=source, + model=model, unsafe_operators=self._settings["scanners"][ SavedModelLambdaDetectScan.full_name() ]["unsafe_keras_operators"], @@ -106,25 +105,23 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults] self.name(), SkipCategories.MODEL_CONFIG, f"Model Config not found", - str(source), + str(model.get_source()), ) ], ) - def _check_model_config(self, source: Union[str, Path]) -> bool: - with h5py.File(source, "r") as model_hdf5: + def _check_model_config(self, model: Model) -> bool: + with h5py.File(model.get_source(), "r") as model_hdf5: if "model_config" in model_hdf5.attrs.keys(): return True else: - logger.error(f"Model Config not found in: {source}") + logger.error(f"Model Config not found in: {model.get_source()}") return False - def _get_keras_h5_operator_names( - self, source: Union[str, Path] - ) -> Optional[List[Any]]: + def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]: # Todo: source isn't guaranteed to be a file - with h5py.File(source, "r") as model_hdf5: + with h5py.File(model.get_source(), "r") as model_hdf5: try: if not "model_config" in model_hdf5.attrs.keys(): return None @@ -138,7 +135,9 @@ def _get_keras_h5_operator_names( layer.get("config", {}).get("function", {}) ) except json.JSONDecodeError as e: - logger.error(f"Not a valid JSON data from source: {source}, error: {e}") + logger.error( + f"Not a valid JSON data from source: {model.get_source()}, error: {e}" + ) return ["JSONDecodeError"] if lambda_layers: diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 72180da..3491abb 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -2,26 +2,23 @@ import zipfile import logging from pathlib import Path -from typing import IO, List, Union, Optional, Any +from typing import IO, List, Union, Optional from modelscan.error import ModelScanError, ErrorCategories from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan +from modelscan.model import Model logger = logging.getLogger("modelscan") class KerasLambdaDetectScan(SavedModelLambdaDetectScan): - def scan( - self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, - ) -> Optional[ScanResults]: + def scan(self, 
model: Model) -> Optional[ScanResults]: if ( - not Path(source).suffix + not model.get_source().suffix in self._settings["scanners"][KerasLambdaDetectScan.full_name()][ "supported_extensions" ] @@ -43,16 +40,19 @@ def scan( ) try: - with zipfile.ZipFile(data or source, "r") as zip: + source = model.get_source() + if model.has_data(): + source = model.get_data() # type: ignore + with zipfile.ZipFile(source, "r") as zip: file_names = zip.namelist() for file_name in file_names: if file_name == "config.json": with zip.open(file_name, "r") as config_file: + model = Model( + f"{model.get_source()}:{file_name}", config_file + ) return self.label_results( - self._scan_keras_config_file( - source=f"{source}:{file_name}", - config_file=config_file, - ) + self._scan_keras_config_file(model) ) except zipfile.BadZipFile as e: return ScanResults( @@ -63,7 +63,7 @@ def scan( self.name(), SkipCategories.BAD_ZIP, f"Skipping zip file due to error: {e}", - f"{source}:{file_name}", + f"{model.get_source()}:{file_name}", ) ], ) @@ -76,20 +76,18 @@ def scan( self.name(), ErrorCategories.MODEL_SCAN, # Giving a generic error category as this return is added to pass mypy f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError - str(source), + str(model.get_source()), ) ], [], ) - def _scan_keras_config_file( - self, source: Union[str, Path], config_file: IO[bytes] - ) -> ScanResults: + def _scan_keras_config_file(self, model: Model) -> ScanResults: machine_learning_library_name = "Keras" # if self._check_json_data(source, config_file): - operators_in_model = self._get_keras_operator_names(source, config_file) + operators_in_model = self._get_keras_operator_names(model) if operators_in_model: if "JSONDecodeError" in operators_in_model: return ScanResults( @@ -99,7 +97,7 @@ def _scan_keras_config_file( self.name(), ErrorCategories.JSON_DECODE, f"Not a valid JSON data", - str(source), + str(model.get_source()), ) ], [], @@ -108,7 +106,7 @@ def _scan_keras_config_file( return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, - source=source, + model=model, unsafe_operators=self._settings["scanners"][ SavedModelLambdaDetectScan.full_name() ]["unsafe_keras_operators"], @@ -121,11 +119,9 @@ def _scan_keras_config_file( [], ) - def _get_keras_operator_names( - self, source: Union[str, Path], data: IO[bytes] - ) -> List[str]: + def _get_keras_operator_names(self, model: Model) -> List[str]: try: - model_config_data = json.load(data) + model_config_data = json.load(model.get_data()) lambda_layers = [ layer.get("config", {}).get("function", {}) @@ -136,7 +132,9 @@ def _get_keras_operator_names( return ["Lambda"] * len(lambda_layers) except json.JSONDecodeError as e: - logger.error(f"Not a valid JSON data from source: {source}, error: {e}") + logger.error( + f"Not a valid JSON data from source: {model.get_source()}, error: {e}" + ) return ["JSONDecodeError"] return [] diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index 2c793ae..3da5f45 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -1,6 +1,5 @@ import logging -from pathlib import Path -from typing import IO, Union, Optional +from typing import Optional from modelscan.scanners.scan import ScanBase, ScanResults from modelscan.tools.utils import _is_zipfile @@ -9,6 +8,7 @@ scan_pickle_bytes, scan_pytorch, ) +from modelscan.model import Model logger = 
logging.getLogger("modelscan") @@ -16,28 +16,32 @@ class PyTorchUnsafeOpScan(ScanBase): def scan( self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, + model: Model, ) -> Optional[ScanResults]: if ( - not Path(source).suffix + not model.get_source().suffix in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][ "supported_extensions" ] ): return None - if _is_zipfile(source, data): + if _is_zipfile( + model.get_source(), model.get_data() if model.has_data() else None + ): return None - if data: - results = scan_pytorch(data=data, source=source, settings=self._settings) + if model.has_data(): + results = scan_pytorch( + model=model, + settings=self._settings, + ) + + return self.label_results(results) - else: - with open(source, "rb") as file_io: - results = scan_pytorch( - data=file_io, source=source, settings=self._settings - ) + with open(model.get_source(), "rb") as file_io: + model = Model(model.get_source(), file_io) + results = scan_pytorch(model=model, settings=self._settings) return self.label_results(results) @@ -53,22 +57,27 @@ def full_name() -> str: class NumpyUnsafeOpScan(ScanBase): def scan( self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, + model: Model, ) -> Optional[ScanResults]: if ( - not Path(source).suffix + not model.get_source().suffix in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][ "supported_extensions" ] ): return None - if data: - results = scan_numpy(data=data, source=source, settings=self._settings) + if model.has_data(): + results = scan_numpy( + model=model, + settings=self._settings, + ) + + return self.label_results(results) - with open(source, "rb") as file_io: - results = scan_numpy(data=file_io, source=source, settings=self._settings) + with open(model.get_source(), "rb") as file_io: + model = Model(model.get_source(), file_io) + results = scan_numpy(model=model, settings=self._settings) return self.label_results(results) @@ -84,27 +93,27 @@ def full_name() -> str: class PickleUnsafeOpScan(ScanBase): def scan( self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, + model: Model, ) -> Optional[ScanResults]: if ( - not Path(source).suffix + not model.get_source().suffix in self._settings["scanners"][PickleUnsafeOpScan.full_name()][ "supported_extensions" ] ): return None - if data: + if model.has_data(): results = scan_pickle_bytes( - data=data, source=source, settings=self._settings + model=model, + settings=self._settings, ) - else: - with open(source, "rb") as file_io: - results = scan_pickle_bytes( - data=file_io, source=source, settings=self._settings - ) + return self.label_results(results) + + with open(model.get_source(), "rb") as file_io: + model = Model(model.get_source(), file_io) + results = scan_pickle_bytes(model=model, settings=self._settings) return self.label_results(results) diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index 35d9d5b..548865b 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -19,6 +19,7 @@ from modelscan.error import ModelScanError, ErrorCategories from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanBase, ScanResults +from modelscan.model import Model logger = logging.getLogger("modelscan") @@ -26,11 +27,10 @@ class SavedModelScan(ScanBase): def scan( self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, + model: Model, ) -> Optional[ScanResults]: if ( - not 
Path(source).suffix + not model.get_source().suffix in self._settings["scanners"][self.full_name()]["supported_extensions"] ): return None @@ -49,19 +49,18 @@ def scan( [], ) - if data: - results = self._scan(source, data) + if model.has_data(): + results = self._scan(model) - else: - with open(source, "rb") as file_io: - results = self._scan(source, data=file_io) + return self.label_results(results) if results else None - if results: - return self.label_results(results) - else: - return None + with open(model.get_source(), "rb") as file_io: + model = Model(model.get_source(), file_io) + results = self._scan(model) + + return self.label_results(results) if results else None - def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: + def _scan(self, model: Model) -> Optional[ScanResults]: raise NotImplementedError # This function checks for malicious operators in both Keras and Tensorflow @@ -69,7 +68,7 @@ def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResul def _check_for_unsafe_tf_keras_operator( module_name: str, raw_operator: List[str], - source: Union[str, Path], + model: Model, unsafe_operators: Dict[str, Any], ) -> ScanResults: issues: List[Issue] = [] @@ -93,7 +92,7 @@ def _check_for_unsafe_tf_keras_operator( details=OperatorIssueDetails( module=module_name, operator=op, - source=source, + source=str(model.get_source()), severity=severity, ), ) @@ -117,44 +116,39 @@ def full_name() -> str: class SavedModelLambdaDetectScan(SavedModelScan): - def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: - file_name = str(source).split("/")[-1] - if file_name == "keras_metadata.pb": - machine_learning_library_name = "Keras" - operators_in_model = self._get_keras_pb_operator_names( - data=data, source=source - ) - if operators_in_model: - if "JSONDecodeError" in operators_in_model: - return ScanResults( - [], - [ - ModelScanError( - self.name(), - ErrorCategories.JSON_DECODE, - f"Not a valid JSON data", - str(source), - ) - ], - [], - ) + def _scan(self, model: Model) -> Optional[ScanResults]: + file_name = str(model.get_source()).split("/")[-1] + if file_name != "keras_metadata.pb": + return None - return SavedModelScan._check_for_unsafe_tf_keras_operator( - machine_learning_library_name, - operators_in_model, - source, - self._settings["scanners"][self.full_name()]["unsafe_keras_operators"], - ) + machine_learning_library_name = "Keras" + operators_in_model = self._get_keras_pb_operator_names(model) + if operators_in_model: + if "JSONDecodeError" in operators_in_model: + return ScanResults( + [], + [ + ModelScanError( + self.name(), + ErrorCategories.JSON_DECODE, + f"Not a valid JSON data", + str(model.get_source()), + ) + ], + [], + ) - else: - return None + return SavedModelScan._check_for_unsafe_tf_keras_operator( + machine_learning_library_name, + operators_in_model, + model, + self._settings["scanners"][self.full_name()]["unsafe_keras_operators"], + ) @staticmethod - def _get_keras_pb_operator_names( - data: IO[bytes], source: Union[str, Path] - ) -> List[str]: + def _get_keras_pb_operator_names(model: Model) -> List[str]: saved_metadata = SavedMetadata() - saved_metadata.ParseFromString(data.read()) + saved_metadata.ParseFromString(model.get_data().read()) try: lambda_layers = [ @@ -170,7 +164,9 @@ def _get_keras_pb_operator_names( return ["Lambda"] * len(lambda_layers) except json.JSONDecodeError as e: - logger.error(f"Not a valid JSON data from source: {source}, error: {e}") + logger.error( + f"Not 
a valid JSON data from source: {str(model.get_source())}, error: {e}" + ) return ["JSONDecodeError"] return [] @@ -181,25 +177,24 @@ def full_name() -> str: class SavedModelTensorflowOpScan(SavedModelScan): - def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: - file_name = str(source).split("/")[-1] + def _scan(self, model: Model) -> Optional[ScanResults]: + file_name = str(model.get_source()).split("/")[-1] if file_name == "keras_metadata.pb": return None - else: - machine_learning_library_name = "Tensorflow" - operators_in_model = self._get_tensorflow_operator_names(data=data) + machine_learning_library_name = "Tensorflow" + operators_in_model = self._get_tensorflow_operator_names(model) return SavedModelScan._check_for_unsafe_tf_keras_operator( machine_learning_library_name, operators_in_model, - source, + model, self._settings["scanners"][self.full_name()]["unsafe_tf_operators"], ) - def _get_tensorflow_operator_names(self, data: IO[bytes]) -> List[str]: + def _get_tensorflow_operator_names(self, model: Model) -> List[str]: saved_model = SavedModel() - saved_model.ParseFromString(data.read()) + saved_model.ParseFromString(model.get_data().read()) model_op_names: Set[str] = set() # Iterate over every metagraph in case there is more than one diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py index 681ef70..89a5981 100644 --- a/modelscan/scanners/scan.py +++ b/modelscan/scanners/scan.py @@ -1,10 +1,10 @@ import abc -from pathlib import Path -from typing import List, Union, Optional, IO, Any, Dict +from typing import List, Optional, Any, Dict from modelscan.error import ModelScanError from modelscan.skip import ModelScanSkipped from modelscan.issues import Issue +from modelscan.model import Model class ScanResults: @@ -43,8 +43,7 @@ def full_name() -> str: @abc.abstractmethod def scan( self, - source: Union[str, Path], - data: Optional[IO[bytes]] = None, + model: Model, ) -> Optional[ScanResults]: raise NotImplementedError diff --git a/modelscan/tools/picklescanner.py b/modelscan/tools/picklescanner.py index d177288..7738456 100644 --- a/modelscan/tools/picklescanner.py +++ b/modelscan/tools/picklescanner.py @@ -1,7 +1,5 @@ import logging import pickletools # nosec -from dataclasses import dataclass -from pathlib import Path from tarfile import TarError from typing import IO, Any, Dict, List, Set, Tuple, Union @@ -11,6 +9,7 @@ from modelscan.skip import ModelScanSkipped, SkipCategories from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanResults +from modelscan.model import Model logger = logging.getLogger("modelscan") @@ -117,17 +116,15 @@ def _list_globals( def scan_pickle_bytes( - data: IO[bytes], - source: Union[Path, str], + model: Model, settings: Dict[str, Any], scan_name: str = "pickle", multiple_pickles: bool = True, ) -> ScanResults: """Disassemble a Pickle stream and report issues""" - issues: List[Issue] = [] try: - raw_globals = _list_globals(data, multiple_pickles) + raw_globals = _list_globals(model.get_data(), multiple_pickles) except GenOpsError as e: return ScanResults( issues, @@ -136,13 +133,13 @@ def scan_pickle_bytes( scan_name, ErrorCategories.PICKLE_GENOPS, f"Parsing error: {e}", - str(source), + str(model.get_source()), ) ], [], ) - logger.debug("Global imports in %s: %s", source, raw_globals) + logger.debug("Global imports in %s: %s", model.get_source(), raw_globals) severities = { "CRITICAL": IssueSeverity.CRITICAL, "HIGH": IssueSeverity.HIGH, 
@@ -176,7 +173,7 @@ def scan_pickle_bytes( details=OperatorIssueDetails( module=global_module, operator=global_name, - source=source, + source=model.get_source(), severity=severity, ), ) @@ -184,18 +181,16 @@ def scan_pickle_bytes( return ScanResults(issues, [], []) -def scan_numpy( - data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any] -) -> ScanResults: +def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults: scan_name = "numpy" # Code to distinguish from NumPy binary files and pickles. _ZIP_PREFIX = b"PK\x03\x04" _ZIP_SUFFIX = b"PK\x05\x06" # empty zip files start with this N = len(np.lib.format.MAGIC_PREFIX) - magic = data.read(N) + magic = model.get_data().read(N) # If the file size is less than N, we need to make sure not # to seek past the beginning of the file - data.seek(-min(N, len(magic)), 1) # back-up + model.get_data().seek(-min(N, len(magic)), 1) # back-up if magic.startswith(_ZIP_PREFIX) or magic.startswith(_ZIP_SUFFIX): # .npz file return ScanResults( @@ -206,40 +201,38 @@ def scan_numpy( scan_name, SkipCategories.NOT_IMPLEMENTED, "Scanning of .npz files is not implemented yet", - str(source), + str(model.get_source()), ) ], ) elif magic == np.lib.format.MAGIC_PREFIX: # .npy file - version = np.lib.format.read_magic(data) # type: ignore[no-untyped-call] + version = np.lib.format.read_magic(model.get_data()) # type: ignore[no-untyped-call] np.lib.format._check_version(version) # type: ignore[attr-defined] - _, _, dtype = np.lib.format._read_array_header(data, version) # type: ignore[attr-defined] + _, _, dtype = np.lib.format._read_array_header(model.get_data(), version) # type: ignore[attr-defined] if dtype.hasobject: - return scan_pickle_bytes(data, source, settings, scan_name) + return scan_pickle_bytes(model, settings, scan_name) else: return ScanResults([], [], []) else: - return scan_pickle_bytes(data, source, settings, scan_name) + return scan_pickle_bytes(model, settings, scan_name) -def scan_pytorch( - data: IO[bytes], source: Union[str, Path], settings: Dict[str, Any] -) -> ScanResults: +def scan_pytorch(model: Model, settings: Dict[str, Any]) -> ScanResults: scan_name = "pytorch" - should_read_directly = _should_read_directly(data) - if should_read_directly and data.tell() == 0: + should_read_directly = _should_read_directly(model.get_data()) + if should_read_directly and model.get_data().tell() == 0: # try loading from tar try: # TODO: implement loading from tar raise TarError() except TarError: # file does not contain a tar - data.seek(0) + model.get_data().seek(0) - magic = get_magic_number(data) + magic = get_magic_number(model.get_data()) if magic != MAGIC_NUMBER: return ScanResults( [], @@ -249,8 +242,9 @@ def scan_pytorch( scan_name, SkipCategories.MAGIC_NUMBER, f"Invalid magic number", - str(source), + str(model.get_source()), ) ], ) - return scan_pickle_bytes(data, source, settings, scan_name, multiple_pickles=False) + + return scan_pickle_bytes(model, settings, scan_name, multiple_pickles=False) diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index cde5c63..e8bfaed 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -34,11 +34,11 @@ ) from modelscan.tools.picklescanner import ( scan_pickle_bytes, - scan_numpy, ) from modelscan.skip import SkipCategories from modelscan.settings import DEFAULT_SETTINGS +from modelscan.model import Model settings: Dict[str, Any] = DEFAULT_SETTINGS @@ -431,12 +431,8 @@ def test_scan_pickle_bytes() -> None: ) ] - assert ( - scan_pickle_bytes( - 
io.BytesIO(pickle.dumps(Malicious1())), "file.pkl", settings - ).issues - == expected - ) + model = Model("file.pkl", io.BytesIO(pickle.dumps(Malicious1()))) + assert scan_pickle_bytes(model, settings).issues == expected def test_scan_zip(zip_file_path: Any) -> None:
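
Reviewer note (illustrative, not part of the patch): a minimal sketch of the
calling convention this diff introduces, using only signatures visible above
-- Model(source, data), scan_pickle_bytes(model, settings), and the
DEFAULT_SETTINGS dict the tests already import. The byte payload and the
"archive.zip:model.pkl" source label are made-up examples, not values from
the patch.

    import io
    import pickle

    from modelscan.model import Model
    from modelscan.settings import DEFAULT_SETTINGS
    from modelscan.tools.picklescanner import scan_pickle_bytes

    # A Model now wraps a source label plus an optional in-memory stream;
    # scanners call get_source()/has_data()/get_data() instead of taking
    # separate source/data arguments.
    payload = io.BytesIO(pickle.dumps({"weights": [0.1, 0.2]}))
    model = Model("archive.zip:model.pkl", payload)

    results = scan_pickle_bytes(model, DEFAULT_SETTINGS)
    print(f"{len(results.issues)} issue(s) found")  # benign payload -> 0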