From 1a8aff723aa235a70a8ddc030949c7d652e5e931 Mon Sep 17 00:00:00 2001 From: Sam Washko Date: Thu, 23 May 2024 14:42:41 -0700 Subject: [PATCH] compare by property value --- modelscan/scanners/h5/scan.py | 4 +++- modelscan/scanners/keras/scan.py | 4 +++- modelscan/scanners/pickle/scan.py | 12 +++++++++--- modelscan/scanners/saved_model/scan.py | 4 +++- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index e6ef3db..bd088f6 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -28,7 +28,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.KERAS_H5 not in model.get_context("formats"): + if SupportedModelFormats.KERAS_H5.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 58fa40f..1e88c38 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -17,7 +17,9 @@ class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan(self, model: Model) -> Optional[ScanResults]: - if SupportedModelFormats.KERAS not in model.get_context("formats"): + if SupportedModelFormats.KERAS.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index a55705c..3ece571 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -19,7 +19,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.PYTORCH not in model.get_context("formats"): + if SupportedModelFormats.PYTORCH.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None if _is_zipfile(model.get_source(), model.get_stream()): @@ -46,7 +48,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.NUMPY not in model.get_context("formats"): + if SupportedModelFormats.NUMPY.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_numpy( @@ -70,7 +74,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.PICKLE not in model.get_context("formats"): + if SupportedModelFormats.PICKLE.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_pickle_bytes( diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index ff3338e..4c8f6f6 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -32,7 +32,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.TENSORFLOW not in model.get_context("formats"): + if SupportedModelFormats.TENSORFLOW.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies()