Skip to content

Commit

Permalink
Follow-up for Model object introduction (#121)
Browse files: browse the repository at this point in the history
  • Loading branch information
asofter authored Mar 22, 2024
1 parent 5f1818b commit f49d8ac
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 50 deletions.
12 changes: 4 additions & 8 deletions modelscan/modelscan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -161,9 +161,10 @@ def scan(

if self._skipped:
all_skipped_paths = [skipped.source for skipped in self._skipped]
for skipped in self._skipped:
main_file_path = skipped.source.split(":")[0]
if main_file_path == skipped.source:
for path in all_paths:
main_file_path = str(path).split(":")[0]

if main_file_path == str(path):
continue

# If main container is skipped, we only add its content to skipped but not the file itself
Expand All @@ -174,11 +175,6 @@ def scan(

continue

# If main container is scanned, we consider all files to be scanned
self._skipped = [
item for item in self._skipped if item.source != skipped.source
]

return self._generate_results()

def _scan_source(
Expand Down
4 changes: 2 additions & 2 deletions modelscan/scanners/h5/scan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -90,7 +90,7 @@ def _scan_keras_h5_file(self, model: Model) -> Optional[ScanResults]:
)

def _check_model_config(self, model: Model) -> bool:
with h5py.File(model.get_stream(), "r") as model_hdf5:
with h5py.File(model.get_stream()) as model_hdf5:
if "model_config" in model_hdf5.attrs.keys():
return True
else:
Expand All @@ -100,7 +100,7 @@ def _check_model_config(self, model: Model) -> bool:
def _get_keras_h5_operator_names(self, model: Model) -> Optional[List[Any]]:
# Todo: source isn't guaranteed to be a file

with h5py.File(model.get_stream(), "r") as model_hdf5:
with h5py.File(model.get_stream()) as model_hdf5:
try:
if not "model_config" in model_hdf5.attrs.keys():
return None
Expand Down
60 changes: 29 additions & 31 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,27 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:

# if self._check_json_data(source, config_file):

operators_in_model = self._get_keras_operator_names(model)
if operators_in_model:
if "JSONDecodeError" in operators_in_model:
return ScanResults(
[],
[
ModelScanError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
str(model.get_source()),
)
],
[],
)
try:
operators_in_model = self._get_keras_operator_names(model)
except json.JSONDecodeError as e:
logger.error(
f"Not a valid JSON data from source: {model.get_source()}, error: {e}"
)

return ScanResults(
[],
[
ModelScanError(
self.name(),
ErrorCategories.JSON_DECODE,
f"Not a valid JSON data",
str(model.get_source()),
)
],
[],
)

if operators_in_model:
return KerasLambdaDetectScan._check_for_unsafe_tf_keras_operator(
module_name=machine_learning_library_name,
raw_operator=operators_in_model,
Expand All @@ -112,22 +117,15 @@ def _scan_keras_config_file(self, model: Model) -> ScanResults:
)

def _get_keras_operator_names(self, model: Model) -> List[str]:
try:
model_config_data = json.load(model.get_stream())

lambda_layers = [
layer.get("config", {}).get("function", {})
for layer in model_config_data.get("config", {}).get("layers", {})
if layer.get("class_name", {}) == "Lambda"
]
if lambda_layers:
return ["Lambda"] * len(lambda_layers)

except json.JSONDecodeError as e:
logger.error(
f"Not a valid JSON data from source: {model.get_source()}, error: {e}"
)
return ["JSONDecodeError"]
model_config_data = json.load(model.get_stream())

lambda_layers = [
layer.get("config", {}).get("function", {})
for layer in model_config_data.get("config", {}).get("layers", {})
if layer.get("class_name", {}) == "Lambda"
]
if lambda_layers:
return ["Lambda"] * len(lambda_layers)

return []

Expand Down
24 changes: 15 additions & 9 deletions tests/test_modelscan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -452,10 +452,7 @@ def test_scan_zip(zip_file_path: Any) -> None:
ms = ModelScan()
results = ms.scan(f"{zip_file_path}/test.zip")
assert results["summary"]["scanned"]["scanned_files"] == [f"test.zip:data.pkl"]
assert [
skipped_file["source"]
for skipped_file in results["summary"]["skipped"]["skipped_files"]
] == ["test.zip"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert ms.issues.all_issues == expected


Expand Down Expand Up @@ -1242,10 +1239,7 @@ def test_scan_directory_path(file_path: str) -> None:
f"benign0_v3.dill",
f"benign0_v4.dill",
}
assert [
skipped_file["source"]
for skipped_file in results["summary"]["skipped"]["skipped_files"]
] == ["malicious1.zip"]
assert results["summary"]["skipped"]["skipped_files"] == []
assert results["errors"] == []


Expand Down Expand Up @@ -1293,7 +1287,19 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
f"safe{file_extension}"
]

assert results["summary"]["skipped"]["skipped_files"] == []
if file_extension == ".keras":
assert set(
[
skipped_file["source"]
for skipped_file in results["summary"]["skipped"]["skipped_files"]
]
) == {
f"safe{file_extension}:metadata.json",
f"safe{file_extension}:config.json",
f"safe{file_extension}:model.weights.h5",
}
else:
assert results["summary"]["skipped"]["skipped_files"] == []

assert results["errors"] == []

Expand Down

0 comments on commit f49d8ac

Please sign in to comment.