Skip to content

Commit

Permalink
Detect multiple Lambda layers (#61)
Browse files (browse the repository at this point in the history)
  • Loading branch information
mehrinkiani authored Dec 1, 2023
1 parent 572df37 commit 07958d0
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 34 deletions.
24 changes: 16 additions & 8 deletions modelscan/models/h5/scan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -54,15 +54,23 @@ def _scan_keras_h5_file(
def _get_keras_h5_operator_names(source: Union[str, Path]) -> List[str]:
# Todo: source isn't guaranteed to be a file
with h5py.File(source, "r") as model_hdf5:
lambda_code = [
layer.get("config", {}).get("function", {})
for layer in json.loads(model_hdf5.attrs["model_config"])["config"][
"layers"
]
if layer["class_name"] == "Lambda"
]
try:
model_config = json.loads(model_hdf5.attrs.get("model_config", {}))
layers = model_config.get("config", {}).get("layers", {})
lambda_layers = []
for layer in layers:
if layer.get("class_name", {}) == "Lambda":
lambda_layers.append(
layer.get("config", {}).get("function", {})
)
except json.JSONDecodeError as e:
logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
return []

if lambda_layers:
return ["Lambda"] * len(lambda_layers)

return ["Lambda"] if lambda_code else []
return []

@staticmethod
def supported_extensions() -> List[str]:
Expand Down
11 changes: 7 additions & 4 deletions modelscan/models/keras/scan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -60,16 +60,19 @@ def _get_keras_operator_names(
) -> List[str]:
try:
model_config_data = json.load(data)
lambda_code = [
lambda_layers = [
layer.get("config", {}).get("function", {})
for layer in model_config_data["config"]["layers"]
if layer["class_name"] == "Lambda"
for layer in model_config_data.get("config", {}).get("layers", {})
if layer.get("class_name", {}) == "Lambda"
]
if lambda_layers:
return ["Lambda"] * len(lambda_layers)

except json.JSONDecodeError as e:
logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
return []

return ["Lambda"] if lambda_code else []
return []

@staticmethod
def supported_extensions() -> List[str]:
Expand Down
37 changes: 25 additions & 12 deletions modelscan/models/saved_model/scan.py
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
# scan pb files for both tensorflow and keras

import json
import logging
from pathlib import Path

from typing import IO, List, Set, Tuple, Union, Optional, Dict
Expand All @@ -20,6 +21,8 @@
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
from modelscan.models.scan import ScanBase

logger = logging.getLogger("modelscan")


class SavedModelScan(ScanBase):
@staticmethod
Expand Down Expand Up @@ -49,7 +52,9 @@ def _scan(
# Default is a tensorflow model file
if file_name == "keras_metadata.pb":
machine_learning_library_name = "Keras"
operators_in_model = SavedModelScan._get_keras_pb_operator_names(data=data)
operators_in_model = SavedModelScan._get_keras_pb_operator_names(
data, source
)

else:
machine_learning_library_name = "Tensorflow"
Expand All @@ -62,22 +67,30 @@ def _scan(
)

@staticmethod
def _get_keras_pb_operator_names(data: IO[bytes]) -> List[str]:
def _get_keras_pb_operator_names(
data: IO[bytes], source: Union[str, Path]
) -> List[str]:
saved_metadata = SavedMetadata()
saved_metadata.ParseFromString(data.read())

lambda_code = [
layer.get("config", {}).get("function", {}).get("items", {})
for layer in [
json.loads(node.metadata)
for node in saved_metadata.nodes
if node.identifier == "_tf_keras_layer"
try:
lambda_layers = [
layer.get("config", {}).get("function", {}).get("items", {})
for layer in [
json.loads(node.metadata)
for node in saved_metadata.nodes
if node.identifier == "_tf_keras_layer"
]
if layer.get("class_name", {}) == "Lambda"
]
if layer["class_name"] == "Lambda"
]
except json.JSONDecodeError as e:
logger.error(f"Not a valid JSON data from source: {source}, error: {e}")
return []

if lambda_layers:
return ["Lambda"] * len(lambda_layers)

# if lambda code is not empty list that means there has been some code injection in Keras layer
return ["Lambda"] if lambda_code else []
return []

@staticmethod
def _get_tensorflow_operator_names(data: IO[bytes]) -> List[str]:
Expand Down
37 changes: 27 additions & 10 deletions tests/test_modelscan.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -21,12 +21,7 @@
from modelscan.modelscan import Modelscan
from modelscan.cli import cli
from modelscan.error import ModelScanError
from modelscan.issues import (
Issue,
IssueCode,
IssueSeverity,
OperatorIssueDetails,
)
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
from modelscan.tools.picklescanner import (
scan_pickle_bytes,
scan_numpy,
Expand Down Expand Up @@ -220,9 +215,12 @@ def keras_file_path(tmp_path_factory: Any, keras_file_extensions: List[str]) ->
or x
)
input_to_new_layer = keras.layers.Dense(1)(keras_model.layers[-1].output)
new_layer = keras.layers.Lambda(attack)(input_to_new_layer)
first_lambda_layer = keras.layers.Lambda(attack)(input_to_new_layer)
second_lambda_layer = keras.layers.Lambda(attack)(first_lambda_layer)

malicious_model = tf.keras.Model(inputs=keras_model.inputs, outputs=[new_layer])
malicious_model = tf.keras.Model(
inputs=keras_model.inputs, outputs=[second_lambda_layer]
)
malicious_model.compile(optimizer="adam", loss="mean_squared_error")

for extension in keras_file_extensions:
Expand Down Expand Up @@ -776,7 +774,16 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
"Lambda",
f"{keras_file_path}/unsafe{file_extension}:config.json",
),
)
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.MEDIUM,
OperatorIssueDetails(
"Keras",
"Lambda",
f"{keras_file_path}/unsafe{file_extension}:config.json",
),
),
]
ms._scan_source(
Path(f"{keras_file_path}/unsafe{file_extension}"),
Expand All @@ -792,9 +799,19 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
"Lambda",
f"{keras_file_path}/unsafe{file_extension}",
),
)
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.MEDIUM,
OperatorIssueDetails(
"Keras",
"Lambda",
f"{keras_file_path}/unsafe{file_extension}",
),
),
]
ms.scan_path(Path(f"{keras_file_path}/unsafe{file_extension}"))

assert ms.issues.all_issues == expected


Expand Down

0 comments on commit 07958d0

Please sign in to comment.