diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index e6ef3db..bd088f6 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -28,7 +28,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.KERAS_H5 not in model.get_context("formats"): + if SupportedModelFormats.KERAS_H5.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 58fa40f..1e88c38 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -17,7 +17,9 @@ class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan(self, model: Model) -> Optional[ScanResults]: - if SupportedModelFormats.KERAS not in model.get_context("formats"): + if SupportedModelFormats.KERAS.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index a55705c..3ece571 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -19,7 +19,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.PYTORCH not in model.get_context("formats"): + if SupportedModelFormats.PYTORCH.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None if _is_zipfile(model.get_source(), model.get_stream()): @@ -46,7 +48,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.NUMPY not in model.get_context("formats"): + if SupportedModelFormats.NUMPY.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_numpy( @@ -70,7 +74,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.PICKLE not in model.get_context("formats"): + if SupportedModelFormats.PICKLE.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_pickle_bytes( diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index ff3338e..4c8f6f6 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -32,7 +32,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if SupportedModelFormats.TENSORFLOW not in model.get_context("formats"): + if SupportedModelFormats.TENSORFLOW.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies()