diff --git a/tests/test_modelscan.py b/tests/test_modelscan.py
index 1900bad..439e270 100644
--- a/tests/test_modelscan.py
+++ b/tests/test_modelscan.py
@@ -19,6 +19,7 @@
 from test_utils import (
     generate_dill_unsafe_file,
     generate_unsafe_pickle_file,
+    MaliciousModule,
     PyTorchTestModel,
 )
 import zipfile
@@ -39,6 +40,8 @@
 
 settings: Dict[str, Any] = DEFAULT_SETTINGS
 
+from modelscan.scanners.saved_model.scan import SavedModelScan
+
 
 class Malicious1:
     def __reduce__(self) -> Any:
@@ -288,9 +291,41 @@ def file_path(tmp_path_factory: Any) -> Any:
     return tmp
 
 
+@pytest.fixture(scope="session")
+def tensorflow_file_path(tmp_path_factory: Any) -> Any:
+    """Save a benign and a malicious TensorFlow SavedModel; return both dirs."""
+    # Create a simple model.
+    inputs = keras.Input(shape=(32,))
+    outputs = keras.layers.Dense(1)(inputs)
+    tensorflow_model = keras.Model(inputs, outputs)
+    tensorflow_model.compile(optimizer="adam", loss="mean_squared_error")
+
+    # Train a model
+    test_input = np.random.random((128, 32))
+    test_target = np.random.random((128, 1))
+    tensorflow_model.fit(test_input, test_target)
+
+    # Save the safe model
+    tmp = tmp_path_factory.mktemp("tensorflow")
+    safe_tensorflow_model_dir = tmp / "saved_model_safe"
+    safe_tensorflow_model_dir.mkdir(parents=True)
+    tensorflow_model.save(safe_tensorflow_model_dir)
+
+    # Create an unsafe model
+    unsafe_tensorflow_model = MaliciousModule(tensorflow_model)
+    unsafe_tensorflow_model.build(input_shape=(32, 32))
+
+    # Save the unsafe model
+    unsafe_tensorflow_model_dir = tmp / "saved_model_unsafe"
+    unsafe_tensorflow_model_dir.mkdir(parents=True)
+    unsafe_tensorflow_model.save(unsafe_tensorflow_model_dir)
+
+    return safe_tensorflow_model_dir, unsafe_tensorflow_model_dir
+
+
 @pytest.fixture(scope="session")
 def keras_file_extensions() -> List[str]:
-    return ["h5", "keras"]
+    return ["h5", "keras", "pb"]
 
 
 @pytest.fixture(scope="session")
@@ -311,7 +346,12 @@ def keras_file_path(tmp_path_factory: Any, keras_file_extensions: List[str]) ->
     with open(f"{tmp}/safe", "wb") as fo:
         pickle.dump(keras_model, fo)
+    safe_saved_model_dir = tmp / "saved_model_safe"  # bound even without "pb"
     for extension in keras_file_extensions:
-        keras_model.save(f"{tmp}/safe.{extension}")
+        if extension == "pb":
+            safe_saved_model_dir.mkdir(parents=True)
+            keras_model.save(f"{safe_saved_model_dir}")
+        else:
+            keras_model.save(f"{tmp}/safe.{extension}")
 
     # Inject code with the command
     command = "exec"
@@ -337,9 +377,14 @@ def keras_file_path(tmp_path_factory: Any, keras_file_extensions: List[str]) ->
     malicious_model.compile(optimizer="adam", loss="mean_squared_error")
 
+    unsafe_saved_model_dir = tmp / "saved_model_unsafe"  # bound even without "pb"
     for extension in keras_file_extensions:
-        malicious_model.save(f"{tmp}/unsafe.{extension}")
+        if extension == "pb":
+            unsafe_saved_model_dir.mkdir(parents=True)
+            malicious_model.save(f"{unsafe_saved_model_dir}")
+        else:
+            malicious_model.save(f"{tmp}/unsafe.{extension}")
 
-    return tmp
+    return tmp, safe_saved_model_dir, unsafe_saved_model_dir
 
 
 @pytest.fixture(scope="session")
@@ -950,10 +995,21 @@ def test_scan_directory_path(file_path: str) -> None:
     compare_results(ms.issues.all_issues, expected)
 
 
-@pytest.mark.parametrize("file_extension", [".h5", ".keras"], ids=["h5", "keras"])
+@pytest.mark.parametrize(
+    "file_extension", [".h5", ".keras", ".pb"], ids=["h5", "keras", "pb"]
+)
 def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
+    """Scan the safe and unsafe Keras artefacts saved for one file format."""
+    # The fixture returns (parent dir, safe SavedModel dir, unsafe SavedModel dir).
+    keras_file_path_parent_dir, safe_saved_model_dir, unsafe_saved_model_dir = (
+        keras_file_path
+    )
     ms = ModelScan()
-    ms.scan(Path(f"{keras_file_path}/safe{file_extension}"))
+    if file_extension == ".pb":
+        ms.scan(Path(f"{safe_saved_model_dir}"))
+    else:
+        ms.scan(Path(f"{keras_file_path_parent_dir}/safe{file_extension}"))
+
     assert ms.issues.all_issues == []
 
     if file_extension == ".keras":
@@ -964,7 +1020,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
                 OperatorIssueDetails(
                     "Keras",
                     "Lambda",
-                    f"{keras_file_path}/unsafe{file_extension}:config.json",
+                    f"{keras_file_path_parent_dir}/unsafe{file_extension}:config.json",
                 ),
             ),
             Issue(
@@ -973,13 +1029,34 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
                 OperatorIssueDetails(
                     "Keras",
                     "Lambda",
-                    f"{keras_file_path}/unsafe{file_extension}:config.json",
+                    f"{keras_file_path_parent_dir}/unsafe{file_extension}:config.json",
                 ),
             ),
         ]
-        ms._scan_source(
-            Path(f"{keras_file_path}/unsafe{file_extension}"),
-        )
+        ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
+    elif file_extension == ".pb":
+        file_name = "keras_metadata.pb"
+        expected = [
+            Issue(
+                IssueCode.UNSAFE_OPERATOR,
+                IssueSeverity.MEDIUM,
+                OperatorIssueDetails(
+                    "Keras",
+                    "Lambda",
+                    f"{unsafe_saved_model_dir}/{file_name}",
+                ),
+            ),
+            Issue(
+                IssueCode.UNSAFE_OPERATOR,
+                IssueSeverity.MEDIUM,
+                OperatorIssueDetails(
+                    "Keras",
+                    "Lambda",
+                    f"{unsafe_saved_model_dir}/{file_name}",
+                ),
+            ),
+        ]
+        ms.scan(Path(f"{unsafe_saved_model_dir}"))
     else:
         expected = [
             Issue(
@@ -988,7 +1065,7 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
                 OperatorIssueDetails(
                     "Keras",
                     "Lambda",
-                    f"{keras_file_path}/unsafe{file_extension}",
+                    f"{keras_file_path_parent_dir}/unsafe{file_extension}",
                 ),
             ),
             Issue(
@@ -997,11 +1074,44 @@ def test_scan_keras(keras_file_path: Any, file_extension: str) -> None:
                 OperatorIssueDetails(
                     "Keras",
                     "Lambda",
-                    f"{keras_file_path}/unsafe{file_extension}",
+                    f"{keras_file_path_parent_dir}/unsafe{file_extension}",
                 ),
             ),
         ]
-        ms._scan_path(Path(f"{keras_file_path}/unsafe{file_extension}"))
+
+        ms.scan(Path(f"{keras_file_path_parent_dir}/unsafe{file_extension}"))
+    assert ms.issues.all_issues == expected
+
+
+def test_scan_tensorflow(tensorflow_file_path: Any) -> None:
+    """Scan the safe, then the unsafe SavedModel and compare reported issues."""
+    safe_tensorflow_model_dir, unsafe_tensorflow_model_dir = tensorflow_file_path
+    ms = ModelScan()
+    ms.scan(Path(f"{safe_tensorflow_model_dir}"))
+    assert ms.issues.all_issues == []
+
+    file_name = "saved_model.pb"
+    expected = [
+        Issue(
+            IssueCode.UNSAFE_OPERATOR,
+            IssueSeverity.HIGH,
+            OperatorIssueDetails(
+                "Tensorflow",
+                "ReadFile",
+                f"{unsafe_tensorflow_model_dir}/{file_name}",
+            ),
+        ),
+        Issue(
+            IssueCode.UNSAFE_OPERATOR,
+            IssueSeverity.HIGH,
+            OperatorIssueDetails(
+                "Tensorflow",
+                "WriteFile",
+                f"{unsafe_tensorflow_model_dir}/{file_name}",
+            ),
+        ),
+    ]
+    ms.scan(Path(f"{unsafe_tensorflow_model_dir}"))
     assert ms.issues.all_issues == expected
 
 
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 4bc4100..6fedb5e 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,6 +6,36 @@
 import os
 import torch
 import torch.nn as nn
+import tensorflow as tf
+from tensorflow import keras
+
+
+class MaliciousModule(keras.Model):  # type: ignore
+    """Test double: wraps a model and adds file read/write ops to its graph."""
+
+    def __init__(self, safe_model) -> None:  # type: ignore
+        super().__init__()
+        self.model = safe_model
+
+    @tf.function(input_signature=[tf.TensorSpec(shape=(32, 32), dtype=tf.float32)])  # type: ignore
+    def call(self, x: tf.Tensor) -> Any:
+        """Return the wrapped model's output after the injected file I/O ops."""
+        # Some model prediction logic
+        res = self.model(x)
+
+        # Write a file
+        tf.io.write_file(
+            "/tmp/aws_secret.txt",
+            "aws_access_key_id=\naws_secret_access_key=",
+        )
+
+        list_ds = tf.data.Dataset.list_files("/tmp/*.txt", shuffle=False)
+
+        for file in list_ds:
+            tf.print("File found: " + file)
+            tf.print(tf.io.read_file(file))
+
+        return res
 
 
 class PickleInject: