From babe843365f472b534681dc32cda0a902368bfe4 Mon Sep 17 00:00:00 2001 From: Oleksandr Yaremchuk Date: Thu, 21 Mar 2024 16:54:20 +0100 Subject: [PATCH] * address optional feedback --- modelscan/modelscan.py | 12 +++---- modelscan/scanners/h5/scan.py | 4 +-- modelscan/scanners/keras/scan.py | 60 +++++++++++++++----------------- tests/test_modelscan.py | 24 ++++++++----- 4 files changed, 50 insertions(+), 50 deletions(-) diff --git a/modelscan/modelscan.py b/modelscan/modelscan.py index 2745990..a8faa86 100644 --- a/modelscan/modelscan.py +++ b/modelscan/modelscan.py @@ -161,9 +161,10 @@ def scan( if self._skipped: all_skipped_paths = [skipped.source for skipped in self._skipped] - for skipped in self._skipped: - main_file_path = skipped.source.split(":")[0] - if main_file_path == skipped.source: + for path in all_paths: + main_file_path = str(path).split(":")[0] + + if main_file_path == str(path): continue # If main container is skipped, we only add its content to skipped but not the file itself @@ -174,11 +175,6 @@ def scan( continue - # If main container is scanned, we consider all files to be scanned - self._skipped = [ - item for item in self._skipped if item.source != skipped.source - ] - return self._generate_results() def _scan_source( diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index d8452e6..3e81cb0 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -90,7 +90,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]: ) def _check_model_config(self, model: Model) -> bool: - with h5py.File(model.get_stream(), "r") as model_hdf5: + with h5py.File(model.get_stream()) as model_hdf5: if "model_config" in model_hdf5.attrs.keys(): return True else: @@ -100,7 +100,7 @@ def _check_model_config(self, model: Model) -> bool: def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]: # Todo: source isn't guaranteed to be a file - with h5py.File(model.get_stream(), "r") as model_hdf5: + with h5py.File(model.get_stream()) as model_hdf5: try: if not "model_config" in model_hdf5.attrs.keys(): return None diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 08f897e..8e6eb12 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -79,22 +79,27 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults: # if self._check_json_data(source, config_file): - operators_in_model = self._get_keras_operator_names(model) - if operators_in_model: - if "JSONDecodeError" in operators_in_model: - return ScanResults( - [], - [ - ModelScanError( - self.name(), - ErrorCategories.JSON_DECODE, - f"Not a valid JSON data", - str(model.get_source()), - ) - ], - [], - ) + try: + operators_in_model = self._get_keras_operator_names(model) + except json.JSONDecodeError as e: + logger.error( + f"Not a valid JSON data from source: {model.get_source()}, error: {e}" + ) + return ScanResults( + [], + [ + ModelScanError( + self.name(), + ErrorCategories.JSON_DECODE, + f"Not a valid JSON data", + str(model.get_source()), + ) + ], + [], + ) + + if operators_in_model: return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator( module_name=machine_learning_library_name, raw_operator=operators_in_model, @@ -112,22 +117,15 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults: ) def _get_keras_operator_names(self, model: Model) -> List[str]: - try: - model_config_data = json.load(model.get_stream()) - - lambda_layers = [ - layer.get("config", {}).get("function", {}) - for layer in model_config_data.get("config", {}).get("layers", {}) - if layer.get("class_name", {}) == "Lambda" - ] - if lambda_layers: - return ["Lambda"] * len(lambda_layers) - - except json.JSONDecodeError as e: - logger.error( - f"Not a valid JSON data from source: {model.get_source()}, error: {e}" - ) - return ["JSONDecodeError"] + model_config_data = json.load(model.get_stream()) + + lambda_layers = [ + layer.get("config", {}).get("function", {}) + for layer in model_config_data.get("config", {}).get("layers", {}) + if layer.get("class_name", {}) == "Lambda" + ] + if lambda_layers: + return ["Lambda"] * len(lambda_layers) return [] diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py index cc9963d..890a705 100644 --- a/tests/test_modelscan.py +++ b/tests/test_modelscan.py @@ -452,10 +452,7 @@ def test_scan_zip(zip_file_path: Any) -> None: ms = ModelScan() results = ms.scan(f"{zip_file_path}/test.zip") assert results["summary"]["scanned"]["scanned_files"] == [f"test.zip:data.pkl"] - assert [ - skipped_file["source"] - for skipped_file in results["summary"]["skipped"]["skipped_files"] - ] == ["test.zip"] + assert results["summary"]["skipped"]["skipped_files"] == [] assert ms.issues.all_issues == expected @@ -1242,10 +1239,7 @@ def test_scan_directory_path(file_path: str) -> None: f"benign0_v3.dill", f"benign0_v4.dill", } - assert [ - skipped_file["source"] - for skipped_file in results["summary"]["skipped"]["skipped_files"] - ] == ["malicious1.zip"] + assert results["summary"]["skipped"]["skipped_files"] == [] assert results["errors"] == [] @@ -1293,7 +1287,19 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None: f"safe{file_extension}" ] - assert results["summary"]["skipped"]["skipped_files"] == [] + if file_extension == ".keras": + assert set( + [ + skipped_file["source"] + for skipped_file in results["summary"]["skipped"]["skipped_files"] + ] + ) == { + f"safe{file_extension}:metadata.json", + f"safe{file_extension}:config.json", + f"safe{file_extension}:model.weights.h5", + } + else: + assert results["summary"]["skipped"]["skipped_files"] == [] assert results["errors"] == []