From 8cfb1b1c678077aadf616d417ac8197de36028bc Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Thu, 1 Feb 2024 12:33:12 -0800 Subject: [PATCH 1/2] skip h5 with no config --- modelscan/scanners/h5/scan.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index afb5b2d..8621fec 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -31,6 +31,10 @@ def scan( ): return None + with h5py.File(source, "r") as model_hdf5: + if not hasattr(model_hdf5, "model_config"): + return None # skip file if there is no model_config + dep_error = self.handle_binary_dependencies() if dep_error: return ScanResults([], [dep_error]) From c5979c7b437a4ada1f6988c4125d3c383d6a92ea Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Thu, 1 Feb 2024 14:36:21 -0800 Subject: [PATCH 2/2] only open file once --- modelscan/scanners/h5/scan.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index 8621fec..a5f976c 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -31,10 +31,6 @@ def scan( ): return None - with h5py.File(source, "r") as model_hdf5: - if not hasattr(model_hdf5, "model_config"): - return None # skip file if there is no model_config - dep_error = self.handle_binary_dependencies() if dep_error: return ScanResults([], [dep_error]) @@ -45,11 +41,17 @@ def scan( ) return None - return self.label_results(self._scan_keras_h5_file(source)) + results = self._scan_keras_h5_file(source) + if results: + return self.label_results(results) + else: + return None - def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults: + def _scan_keras_h5_file(self, source: Union[str, Path]) -> Optional[ScanResults]: machine_learning_library_name = "Keras" operators_in_model = self._get_keras_h5_operator_names(source) + if not operators_in_model: + return None return H5LambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, @@ -59,11 +61,15 @@ def _scan_keras_h5_file(self, source: Union[str, Path]) -> ScanResults: ]["unsafe_keras_operators"], ) - def _get_keras_h5_operator_names(self, source: Union[str, Path]) -> List[str]: + def _get_keras_h5_operator_names( + self, source: Union[str, Path] + ) -> Optional[List[str]]: # Todo: source isn't guaranteed to be a file with h5py.File(source, "r") as model_hdf5: try: + if not "model_config" in model_hdf5.attrs.keys(): + return None model_config = json.loads(model_hdf5.attrs.get("model_config", {})) layers = model_config.get("config", {}).get("layers", {}) lambda_layers = []