diff --git a/modelscan/scanners/__init__.py b/modelscan/scanners/__init__.py index 2a45017..fc09d38 100644 --- a/modelscan/scanners/__init__.py +++ b/modelscan/scanners/__init__.py @@ -1,8 +1,12 @@ -from modelscan.scanners.h5.scan import H5Scan +from modelscan.scanners.h5.scan import H5LambdaDetectScan from modelscan.scanners.pickle.scan import ( - PickleScan, - NumpyScan, - PyTorchScan, + PickleUnsafeOpScan, + NumpyUnsafeOpScan, + PyTorchUnsafeOpScan, ) -from modelscan.scanners.saved_model.scan import SavedModelScan -from modelscan.scanners.keras.scan import KerasScan +from modelscan.scanners.saved_model.scan import ( + SavedModelScan, + SavedModelLambdaDetectScan, + SavedModelTensorflowOpScan, +) +from modelscan.scanners.keras.scan import KerasLambdaDetectScan diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index 8855ed7..afb5b2d 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -12,12 +12,12 @@ from modelscan.error import ModelScanError from modelscan.scanners.scan import ScanResults -from modelscan.scanners.saved_model.scan import SavedModelScan +from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan logger = logging.getLogger("modelscan") -class H5Scan(SavedModelScan): +class H5LambdaDetectScan(SavedModelLambdaDetectScan): def scan( self, source: Union[str, Path], @@ -25,7 +25,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings["scanners"][H5Scan.full_name()]["supported_extensions"] + in self._settings["scanners"][H5LambdaDetectScan.full_name()][ + "supported_extensions" + ] ): return None @@ -44,11 +46,13 @@ def scan( def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults: machine_learning_library_name = "Keras" operators_in_model = self._get_keras_h5_operator_names(source) - return H5Scan._check_for_unsafe_tf_keras_operator( + return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, source=source, - settings=self._settings, + unsafe_operators=self._settings["scanners"][ + SavedModelLambdaDetectScan.full_name() + ]["unsafe_keras_operators"], ) def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]: @@ -73,21 +77,20 @@ def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]: return [] - @staticmethod - def name() -> str: - return "hdf5" - - @staticmethod - def full_name() -> str: - return "modelscan.scanners.H5Scan" - - @staticmethod def handle_binary_dependencies( - settings: Optional[Dict[str, Any]] = None + self, settings: Optional[Dict[str, Any]] = None ) -> Optional[ModelScanError]: if not h5py_installed: return ModelScanError( - SavedModelScan.name(), - f"To use {H5Scan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.", + H5LambdaDetectScan.name(), + f"To use {H5LambdaDetectScan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.", ) return None + + @staticmethod + def name() -> str: + return "hdf5" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.H5LambdaDetectScan" diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index f2ece66..74323ac 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -7,13 +7,13 @@ from modelscan.error import ModelScanError from modelscan.scanners.scan import ScanResults -from modelscan.scanners.saved_model.scan import SavedModelScan +from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan logger = logging.getLogger("modelscan") -class KerasScan(SavedModelScan): +class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan( self, source: Union[str, Path], @@ -21,7 +21,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings["scanners"][KerasScan.full_name()]["supported_extensions"] + in self._settings["scanners"][KerasLambdaDetectScan.full_name()][ + "supported_extensions" + ] ): return None @@ -46,7 +48,7 @@ def scan( [], [ ModelScanError( - KerasScan.name(), + KerasLambdaDetectScan.name(), f"Skipping zip file {source}, due to error: {e}", ) ], @@ -57,7 +59,7 @@ def scan( [], [ ModelScanError( - KerasScan.name(), + KerasLambdaDetectScan.name(), f"Unable to scan .keras file", # Not sure if this is a representative message for ModelScanError ) ], @@ -68,11 +70,13 @@ def _scan_keras_config_file( ) -> ScanResults: machine_learning_library_name = "Keras" operators_in_model = self._get_keras_operator_names(source, config_file) - return KerasScan._check_for_unsafe_tf_keras_operator( + return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, source=source, - settings=self._settings, + unsafe_operators=self._settings["scanners"][ + SavedModelLambdaDetectScan.full_name() + ]["unsafe_keras_operators"], ) def _get_keras_operator_names( @@ -101,4 +105,4 @@ def name() -> str: @staticmethod def full_name() -> str: - return "modelscan.scanners.KerasScan" + return "modelscan.scanners.KerasLambdaDetectScan" diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index 1f2fc6c..2a84399 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -13,7 +13,7 @@ logger = logging.getLogger("modelscan") -class PyTorchScan(ScanBase): +class PyTorchUnsafeOpScan(ScanBase): def scan( self, source: Union[str, Path], @@ -21,7 +21,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings["scanners"][PyTorchScan.full_name()][ + in self._settings["scanners"][PyTorchUnsafeOpScan.full_name()][ "supported_extensions" ] ): @@ -47,10 +47,10 @@ def name() -> str: @staticmethod def full_name() -> str: - return "modelscan.scanners.PyTorchScan" + return "modelscan.scanners.PyTorchUnsafeOpScan" -class NumpyScan(ScanBase): +class NumpyUnsafeOpScan(ScanBase): def scan( self, source: Union[str, Path], @@ -58,7 +58,9 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings["scanners"][NumpyScan.full_name()]["supported_extensions"] + in self._settings["scanners"][NumpyUnsafeOpScan.full_name()][ + "supported_extensions" + ] ): return None @@ -76,10 +78,10 @@ def name() -> str: @staticmethod def full_name() -> str: - return "modelscan.scanners.NumpyScan" + return "modelscan.scanners.NumpyUnsafeOpScan" -class PickleScan(ScanBase): +class PickleUnsafeOpScan(ScanBase): def scan( self, source: Union[str, Path], @@ -87,7 +89,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings["scanners"][PickleScan.full_name()][ + in self._settings["scanners"][PickleUnsafeOpScan.full_name()][ "supported_extensions" ] ): @@ -112,4 +114,4 @@ def name() -> str: @staticmethod def full_name() -> str: - return "modelscan.scanners.PickleScan" + return "modelscan.scanners.PickleUnsafeOpScan" diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index 0b2dd4c..6eecbcd 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -31,9 +31,7 @@ def scan( ) -> Optional[ScanResults]: if ( not Path(source).suffix - in self._settings["scanners"][SavedModelScan.full_name()][ - "supported_extensions" - ] + in self._settings["scanners"][self.full_name()]["supported_extensions"] ): return None @@ -48,11 +46,69 @@ def scan( with open(source, "rb") as file_io: results = self._scan(source, data=file_io) - return self.label_results(results) + if results: + return self.label_results(results) + else: + return None + + def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: + raise NotImplementedError + + # This function checks for malicious operators in both Keras and Tensorflow + @staticmethod + def _check_for_unsafe_tf_keras_operator( + module_name: str, + raw_operator: List[str], + source: Union[str, Path], + unsafe_operators: Dict[str, Any], + ) -> ScanResults: + issues: List[Issue] = [] + all_operators = tensorflow.raw_ops.__dict__.keys() + all_safe_operators = [ + operator for operator in list(all_operators) if operator[0] != "_" + ] + + for op in raw_operator: + if op in unsafe_operators: + severity = IssueSeverity[unsafe_operators[op]] + elif op not in all_safe_operators: + severity = IssueSeverity.MEDIUM + else: + continue + + issues.append( + Issue( + code=IssueCode.UNSAFE_OPERATOR, + severity=severity, + details=OperatorIssueDetails( + module=module_name, operator=op, source=source + ), + ) + ) + return ScanResults(issues, []) + + def handle_binary_dependencies( + self, settings: Optional[Dict[str, Any]] = None + ) -> Optional[ModelScanError]: + if not tensorflow_installed: + return ModelScanError( + self.name(), + f"To use {self.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.", + ) + return None - def _scan(self, source: Union[str, Path], data: IO[bytes]) -> ScanResults: + @staticmethod + def name() -> str: + return "saved_model" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.SavedModelScan" + + +class SavedModelLambdaDetectScan(SavedModelScan): + def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: file_name = str(source).split("/")[-1] - # Default is a tensorflow model file if file_name == "keras_metadata.pb": machine_learning_library_name = "Keras" operators_in_model = self._get_keras_pb_operator_names( @@ -60,11 +116,13 @@ def _scan(self, source: Union[str, Path], data: IO[bytes]) -> ScanResults: ) else: - machine_learning_library_name = "Tensorflow" - operators_in_model = self._get_tensorflow_operator_names(data=data) + return None return SavedModelScan._check_for_unsafe_tf_keras_operator( - machine_learning_library_name, operators_in_model, source, self._settings + machine_learning_library_name, + operators_in_model, + source, + self._settings["scanners"][self.full_name()]["unsafe_keras_operators"], ) @staticmethod @@ -93,6 +151,28 @@ def _get_keras_pb_operator_names( return [] + @staticmethod + def full_name() -> str: + return "modelscan.scanners.SavedModelLambdaDetectScan" + + +class SavedModelTensorflowOpScan(SavedModelScan): + def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: + file_name = str(source).split("/")[-1] + if file_name == "keras_metadata.pb": + return None + + else: + machine_learning_library_name = "Tensorflow" + operators_in_model = self._get_tensorflow_operator_names(data=data) + + return SavedModelScan._check_for_unsafe_tf_keras_operator( + machine_learning_library_name, + operators_in_model, + source, + self._settings["scanners"][self.full_name()]["unsafe_tf_operators"], + ) + def _get_tensorflow_operator_names(self, data: IO[bytes]) -> List[str]: saved_model = SavedModel() saved_model.ParseFromString(data.read()) @@ -109,58 +189,6 @@ def _get_tensorflow_operator_names(self, data: IO[bytes]) -> List[str]: # Sort and convert to list return list(sorted(model_op_names)) - # This function checks for malicious operators in both Keras and Tensorflow - @staticmethod - def _check_for_unsafe_tf_keras_operator( - module_name: str, - raw_operator: List[str], - source: Union[str, Path], - settings: Dict[str, Any], - ) -> ScanResults: - unsafe_operators: Dict[str, Any] = settings["scanners"][ - SavedModelScan.full_name() - ]["unsafe_tf_keras_operators"] - - issues: List[Issue] = [] - all_operators = tensorflow.raw_ops.__dict__.keys() - all_safe_operators = [ - operator for operator in list(all_operators) if operator[0] != "_" - ] - - for op in raw_operator: - if op in unsafe_operators: - severity = IssueSeverity[unsafe_operators[op]] - elif op not in all_safe_operators: - severity = IssueSeverity.MEDIUM - else: - continue - - issues.append( - Issue( - code=IssueCode.UNSAFE_OPERATOR, - severity=severity, - details=OperatorIssueDetails( - module=module_name, operator=op, source=source - ), - ) - ) - return ScanResults(issues, []) - - @staticmethod - def name() -> str: - return "saved_model" - @staticmethod def full_name() -> str: - return "modelscan.scanners.SavedModelScan" - - @staticmethod - def handle_binary_dependencies( - settings: Optional[Dict[str, Any]] = None - ) -> Optional[ModelScanError]: - if not tensorflow_installed: - return ModelScanError( - SavedModelScan.name(), - f"To use {SavedModelScan.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.", - ) - return None + return "modelscan.scanners.SavedModelTensorflowOpScan" diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py index 47aaf40..681c87c 100644 --- a/modelscan/scanners/scan.py +++ b/modelscan/scanners/scan.py @@ -40,9 +40,8 @@ def scan( ) -> Optional[ScanResults]: raise NotImplementedError - @staticmethod def handle_binary_dependencies( - settings: Optional[Dict[str, Any]] = None + self, settings: Optional[Dict[str, Any]] = None ) -> Optional[ModelScanError]: """ Implement this method if the plugin requires a binary dependency. diff --git a/modelscan/settings.py b/modelscan/settings.py index c4bad2e..8041ffe 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -2,36 +2,45 @@ from typing import Any +from modelscan._version import __version__ + DEFAULT_REPORTING_MODULES = { "console": "modelscan.reports.ConsoleReport", "json": "modelscan.reports.JSONReport", } DEFAULT_SETTINGS = { + "modelscan_version": __version__, "supported_zip_extensions": [".zip", ".npz"], "scanners": { - "modelscan.scanners.H5Scan": { + "modelscan.scanners.H5LambdaDetectScan": { "enabled": True, "supported_extensions": [".h5"], }, - "modelscan.scanners.KerasScan": { + "modelscan.scanners.KerasLambdaDetectScan": { "enabled": True, "supported_extensions": [".keras"], }, - "modelscan.scanners.SavedModelScan": { + "modelscan.scanners.SavedModelLambdaDetectScan": { "enabled": True, "supported_extensions": [".pb"], - "unsafe_tf_keras_operators": { + "unsafe_keras_operators": { + "Lambda": "MEDIUM", + }, + }, + "modelscan.scanners.SavedModelTensorflowOpScan": { + "enabled": True, + "supported_extensions": [".pb"], + "unsafe_tf_operators": { "ReadFile": "HIGH", "WriteFile": "HIGH", - "Lambda": "MEDIUM", }, }, - "modelscan.scanners.NumpyScan": { + "modelscan.scanners.NumpyUnsafeOpScan": { "enabled": True, "supported_extensions": [".npy"], }, - "modelscan.scanners.PickleScan": { + "modelscan.scanners.PickleUnsafeOpScan": { "enabled": True, "supported_extensions": [ ".pkl", @@ -42,7 +51,7 @@ ".data", ], }, - "modelscan.scanners.PyTorchScan": { + "modelscan.scanners.PyTorchUnsafeOpScan": { "enabled": True, "supported_extensions": [".bin", ".pt", ".pth", ".ckpt"], },