Skip to content

Commit

Permalink
Add tests for pb/SavedModel
Browse files Browse the repository at this point in the history
  • Loading branch information
mehrinkiani committed Nov 29, 2023
1 parent 572df37 commit 26da99d
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 11 deletions.
129 changes: 118 additions & 11 deletions tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
import tensorflow as tf
from tensorflow import keras
from typing import Any, List, Set
from test_utils import generate_dill_unsafe_file, generate_unsafe_pickle_file
from test_utils import (
generate_dill_unsafe_file,
generate_unsafe_pickle_file,
MaliciousModule,
)
import zipfile

from modelscan.modelscan import Modelscan
Expand All @@ -32,6 +36,8 @@
scan_numpy,
)

from modelscan.models.saved_model.scan import SavedModelScan


class Malicious1:
def __reduce__(self) -> Any:
Expand Down Expand Up @@ -180,9 +186,41 @@ def file_path(tmp_path_factory: Any) -> Any:
return tmp


@pytest.fixture(scope="session")
def tensorflow_file_path(tmp_path_factory: Any) -> Any:
# Create a simple model.
inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
tensorflow_model = keras.Model(inputs, outputs)
tensorflow_model.compile(optimizer="adam", loss="mean_squared_error")

# Train a model
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
tensorflow_model.fit(test_input, test_target)

# Save the safe model
tmp = tmp_path_factory.mktemp("tensorflow")
safe_tensorflow_model_dir = tmp / "saved_model_safe"
safe_tensorflow_model_dir.mkdir(parents=True)
tensorflow_model.save(safe_tensorflow_model_dir)

# Create an unsafe model
unsafe_tensorflow_model = MaliciousModule(tensorflow_model)
unsafe_tensorflow_model.build(input_shape=(32, 32))

# Save the unsafe model
unsafe_tensorflow_model_dir = tmp / "saved_model_unsafe"
unsafe_tensorflow_model_dir.mkdir(parents=True)
unsafe_model_path = os.path.join(unsafe_tensorflow_model_dir)
unsafe_tensorflow_model.save(unsafe_model_path)

return safe_tensorflow_model_dir, unsafe_tensorflow_model_dir


@pytest.fixture(scope="session")
def keras_file_extensions() -> List[str]:
return ["h5", "keras"]
return ["h5", "keras", "pb"]


@pytest.fixture(scope="session")
Expand All @@ -203,7 +241,12 @@ def keras_file_path(tmp_path_factory: Any, keras_file_extensions: List[str]) ->
with open(f"{tmp}/safe", "wb") as fo:
pickle.dump(keras_model, fo)
for extension in keras_file_extensions:
keras_model.save(f"{tmp}/safe.{extension}")
if extension == "pb":
safe_saved_model_dir = tmp / "saved_model_safe"
safe_saved_model_dir.mkdir(parents=True)
keras_model.save(f"{safe_saved_model_dir}")
else:
keras_model.save(f"{tmp}/safe.{extension}")

# Inject code with the command
command = "exec"
Expand All @@ -226,9 +269,14 @@ def keras_file_path(tmp_path_factory: Any, keras_file_extensions: List[str]) ->
malicious_model.compile(optimizer="adam", loss="mean_squared_error")

for extension in keras_file_extensions:
malicious_model.save(f"{tmp}/unsafe.{extension}")
if extension == "pb":
unsafe_saved_model_dir = tmp / "saved_model_unsafe"
unsafe_saved_model_dir.mkdir(parents=True)
malicious_model.save(f"{unsafe_saved_model_dir}")
else:
malicious_model.save(f"{tmp}/unsafe.{extension}")

return tmp
return tmp, safe_saved_model_dir, unsafe_saved_model_dir


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -760,10 +808,20 @@ def test_scan_directory_path(file_path: str) -> None:
compare_results(ms.issues.all_issues, expected)


@pytest.mark.parametrize("file_extension", [".h5", ".keras"], ids=["h5", "keras"])
@pytest.mark.parametrize(
"file_extension", [".h5", ".keras", ".pb"], ids=["h5", "keras", "pb"]
)
def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
keras_file_path_parent_dir, safe_saved_model_dir, unsafe_saved_model_dir = (
keras_file_path[0],
keras_file_path[1],
keras_file_path[2],
)
ms = Modelscan()
ms.scan_path(Path(f"{keras_file_path}/safe{file_extension}"))
if file_extension == ".pb":
ms.scan_path(Path(f"{safe_saved_model_dir}"))
else:
ms.scan_path(Path(f"{keras_file_path_parent_dir}/safe{file_extension}"))
assert ms.issues.all_issues == []

if file_extension == ".keras":
Expand All @@ -774,14 +832,28 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
OperatorIssueDetails(
"Keras",
"Lambda",
f"{keras_file_path}/unsafe{file_extension}:config.json",
f"{keras_file_path_parent_dir}/unsafe{file_extension}:config.json",
),
)
]
ms._scan_source(
Path(f"{keras_file_path}/unsafe{file_extension}"),
Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"),
extension=file_extension,
)
elif file_extension == ".pb":
file_name = "keras_metadata.pb"
expected = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.MEDIUM,
OperatorIssueDetails(
"Keras",
"Lambda",
f"{unsafe_saved_model_dir}/{file_name}",
),
),
]
ms.scan_path(Path(f"{unsafe_saved_model_dir}"))
else:
expected = [
Issue(
Expand All @@ -790,11 +862,46 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
OperatorIssueDetails(
"Keras",
"Lambda",
f"{keras_file_path}/unsafe{file_extension}",
f"{keras_file_path_parent_dir}/unsafe{file_extension}",
),
)
]
ms.scan_path(Path(f"{keras_file_path}/unsafe{file_extension}"))
ms.scan_path(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
assert ms.issues.all_issues == expected


def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
safe_tensorflow_model_dir, unsafe_tensorflow_model_dir = (
tensorflow_file_path[0],
tensorflow_file_path[1],
)
ms = Modelscan()
ms.scan_path(Path(f"{safe_tensorflow_model_dir}"))
assert ms.issues.all_issues == []

file_name = "saved_model.pb"
expected = [
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.HIGH,
OperatorIssueDetails(
"Tensorflow",
"ReadFile",
f"{unsafe_tensorflow_model_dir}/{file_name}",
),
),
Issue(
IssueCode.UNSAFE_OPERATOR,
IssueSeverity.HIGH,
OperatorIssueDetails(
"Tensorflow",
"WriteFile",
f"{unsafe_tensorflow_model_dir}/{file_name}",
),
),
]
ms.scan_path(Path(f"{unsafe_tensorflow_model_dir}"))

assert ms.issues.all_issues == expected


Expand Down
27 changes: 27 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,33 @@
import struct
from typing import Any, Tuple
import os
import tensorflow as tf
from tensorflow import keras


class MaliciousModule(keras.Model): # type: ignore
def __init__(self, safe_model) -> None: # type: ignore
super(MaliciousModule, self).__init__()
self.model = safe_model

@tf.function(input_signature=[tf.TensorSpec(shape=(32, 32), dtype=tf.float32)]) # type: ignore
def call(self, x: float) -> Any:
# Some model prediction logic
res = self.model(x)

# Write a file
tf.io.write_file(
"/tmp/aws_secret.txt",
"aws_access_key_id=<access_key_id>\naws_secret_access_key=<aws_secret_key>",
)

list_ds = tf.data.Dataset.list_files("/tmp/*.txt", shuffle=False)

for file in list_ds:
tf.print("File found: " + file)
tf.print(tf.io.read_file(file))

return res


class PickleInject:
Expand Down

0 comments on commit 26da99d

Please sign in to comment.