From 1e78ba4350b0aa848558b166fdbd4e3687d67f7e Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Wed, 17 Jan 2024 11:27:25 -0800 Subject: [PATCH] make handle_binary_dependencies class method --- modelscan/scanners/h5/scan.py | 19 +++++++++---------- modelscan/scanners/saved_model/scan.py | 21 ++++++++++----------- modelscan/scanners/scan.py | 3 +-- 3 files changed, 20 insertions(+), 23 deletions(-) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index b6a87ea..afb5b2d 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -77,17 +77,8 @@ def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]: return [] - @staticmethod - def name() -> str: - return "hdf5" - - @staticmethod - def full_name() -> str: - return "modelscan.scanners.H5LambdaDetectScan" - - @staticmethod def handle_binary_dependencies( - settings: Optional[Dict[str, Any]] = None + self, settings: Optional[Dict[str, Any]] = None ) -> Optional[ModelScanError]: if not h5py_installed: return ModelScanError( @@ -95,3 +86,11 @@ def handle_binary_dependencies( f"To use {H5LambdaDetectScan.full_name()}, please install modelscan with h5py extras. 'pip install \"modelscan\[h5py]\"' if you are using pip.", ) return None + + @staticmethod + def name() -> str: + return "hdf5" + + @staticmethod + def full_name() -> str: + return "modelscan.scanners.H5LambdaDetectScan" diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index 30e2396..6eecbcd 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -87,6 +87,16 @@ def _check_for_unsafe_tf_keras_operator( ) return ScanResults(issues, []) + def handle_binary_dependencies( + self, settings: Optional[Dict[str, Any]] = None + ) -> Optional[ModelScanError]: + if not tensorflow_installed: + return ModelScanError( + self.name(), + f"To use {self.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.", + ) + return None + @staticmethod def name() -> str: return "saved_model" @@ -95,17 +105,6 @@ def name() -> str: def full_name() -> str: return "modelscan.scanners.SavedModelScan" - @staticmethod - def handle_binary_dependencies( - settings: Optional[Dict[str, Any]] = None - ) -> Optional[ModelScanError]: - if not tensorflow_installed: - return ModelScanError( - SavedModelScan.name(), - f"To use {SavedModelScan.full_name()}, please install modelscan with tensorflow extras. 'pip install \"modelscan\[tensorflow]\"' if you are using pip.", - ) - return None - class SavedModelLambdaDetectScan(SavedModelScan): def _scan(self, source: Union[str, Path], data: IO[bytes]) -> Optional[ScanResults]: diff --git a/modelscan/scanners/scan.py b/modelscan/scanners/scan.py index 47aaf40..681c87c 100644 --- a/modelscan/scanners/scan.py +++ b/modelscan/scanners/scan.py @@ -40,9 +40,8 @@ def scan( ) -> Optional[ScanResults]: raise NotImplementedError - @staticmethod def handle_binary_dependencies( - settings: Optional[Dict[str, Any]] = None + self, settings: Optional[Dict[str, Any]] = None ) -> Optional[ModelScanError]: """ Implement this method if the plugin requires a binary dependency.