diff --git a/modelscan/error.py b/modelscan/error.py index 7e227c4..d471169 100644 --- a/modelscan/error.py +++ b/modelscan/error.py @@ -1,4 +1,3 @@ -from enum import Enum from modelscan.model import Model import abc from pathlib import Path diff --git a/modelscan/issues.py b/modelscan/issues.py index 130318d..16bfb51 100644 --- a/modelscan/issues.py +++ b/modelscan/issues.py @@ -6,6 +6,8 @@ from collections import defaultdict +from modelscan.settings import Property + logger = logging.getLogger("modelscan") @@ -16,8 +18,8 @@ class IssueSeverity(Enum): CRITICAL = 4 -class IssueCode(Enum): - UNSAFE_OPERATOR = 1 +class IssueCode: + UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1) class IssueDetails(metaclass=abc.ABCMeta): @@ -40,14 +42,14 @@ class Issue: def __init__( self, - code: IssueCode, + code: Property, severity: IssueSeverity, details: IssueDetails, ) -> None: """ Create a issue with given information - :param code: Code of the issue from the issue code enum. + :param code: Code of the issue from the issue code class. :param severity: The severity level of the issue from Severity enum. :param details: An implementation of the IssueDetails object. """ @@ -82,7 +84,7 @@ def __hash__(self) -> int: def print(self) -> None: issue_description = self.code.name - if self.code == IssueCode.UNSAFE_OPERATOR: + if self.code.value == IssueCode.UNSAFE_OPERATOR.value: issue_description = "Unsafe operator" else: logger.error("No issue description for issue code %s", self.code) diff --git a/modelscan/scanners/h5/scan.py b/modelscan/scanners/h5/scan.py index c398535..bd088f6 100644 --- a/modelscan/scanners/h5/scan.py +++ b/modelscan/scanners/h5/scan.py @@ -18,6 +18,7 @@ from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan from modelscan.model import Model +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -27,7 +28,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "keras_h5" not in model.get_context("formats"): + if SupportedModelFormats.KERAS_H5.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/keras/scan.py b/modelscan/scanners/keras/scan.py index 2a7fb5e..1e88c38 100644 --- a/modelscan/scanners/keras/scan.py +++ b/modelscan/scanners/keras/scan.py @@ -9,6 +9,7 @@ from modelscan.scanners.scan import ScanResults from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan from modelscan.model import Model +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -16,7 +17,9 @@ class KerasLambdaDetectScan(SavedModelLambdaDetectScan): def scan(self, model: Model) -> Optional[ScanResults]: - if "keras" not in model.get_context("formats"): + if SupportedModelFormats.KERAS.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/scanners/pickle/scan.py b/modelscan/scanners/pickle/scan.py index d138202..3ece571 100644 --- a/modelscan/scanners/pickle/scan.py +++ b/modelscan/scanners/pickle/scan.py @@ -9,6 +9,7 @@ scan_pytorch, ) from modelscan.model import Model +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -18,7 +19,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "pytorch" not in model.get_context("formats"): + if SupportedModelFormats.PYTORCH.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None if _is_zipfile(model.get_source(), model.get_stream()): @@ -45,7 +48,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "numpy" not in model.get_context("formats"): + if SupportedModelFormats.NUMPY.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_numpy( @@ -69,7 +74,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "pickle" not in model.get_context("formats"): + if SupportedModelFormats.PICKLE.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None results = scan_pickle_bytes( diff --git a/modelscan/scanners/saved_model/scan.py b/modelscan/scanners/saved_model/scan.py index 74e8fb8..4c8f6f6 100644 --- a/modelscan/scanners/saved_model/scan.py +++ b/modelscan/scanners/saved_model/scan.py @@ -22,6 +22,7 @@ from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails from modelscan.scanners.scan import ScanBase, ScanResults from modelscan.model import Model +from modelscan.settings import SupportedModelFormats logger = logging.getLogger("modelscan") @@ -31,7 +32,9 @@ def scan( self, model: Model, ) -> Optional[ScanResults]: - if "tf_saved_model" not in model.get_context("formats"): + if SupportedModelFormats.TENSORFLOW.value not in [ + format_property.value for format_property in model.get_context("formats") + ]: return None dep_error = self.handle_binary_dependencies() diff --git a/modelscan/settings.py b/modelscan/settings.py index 5f4e6ed..395dfbe 100644 --- a/modelscan/settings.py +++ b/modelscan/settings.py @@ -4,6 +4,22 @@ from modelscan._version import __version__ + +class Property: + def __init__(self, name: str, value: Any) -> None: + self.name = name + self.value = value + + +class SupportedModelFormats: + TENSORFLOW = Property("TENSORFLOW", "tensorflow") + KERAS_H5 = Property("KERAS_H5", "keras_h5") + KERAS = Property("KERAS", "keras") + NUMPY = Property("NUMPY", "numpy") + PYTORCH = Property("PYTORCH", "pytorch") + PICKLE = Property("PICKLE", "pickle") + + DEFAULT_REPORTING_MODULES = { "console": "modelscan.reports.ConsoleReport", "json": "modelscan.reports.JSONReport", @@ -59,13 +75,12 @@ "middlewares": { "modelscan.middlewares.FormatViaExtensionMiddleware": { "formats": { - "tf": [".pb"], - "tf_saved_model": [".pb"], - "keras_h5": [".h5"], - "keras": [".keras"], - "numpy": [".npy"], - "pytorch": [".bin", ".pt", ".pth", ".ckpt"], - "pickle": [ + SupportedModelFormats.TENSORFLOW: [".pb"], + SupportedModelFormats.KERAS_H5: [".h5"], + SupportedModelFormats.KERAS: [".keras"], + SupportedModelFormats.NUMPY: [".npy"], + SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"], + SupportedModelFormats.PICKLE: [ ".pkl", ".pickle", ".joblib", diff --git a/modelscan/skip.py b/modelscan/skip.py index 27272e6..2f83b75 100644 --- a/modelscan/skip.py +++ b/modelscan/skip.py @@ -1,17 +1,18 @@ import logging from enum import Enum +from modelscan.settings import Property logger = logging.getLogger("modelscan") -class SkipCategories(Enum): - SCAN_NOT_SUPPORTED = 1 - BAD_ZIP = 2 - MODEL_CONFIG = 3 - H5_DATA = 4 - NOT_IMPLEMENTED = 5 - MAGIC_NUMBER = 6 +class SkipCategories: + SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1) + BAD_ZIP = Property("BAD_ZIP", 2) + MODEL_CONFIG = Property("MODEL_CONFIG", 3) + H5_DATA = Property("H5_DATA", 4) + NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5) + MAGIC_NUMBER = Property("MAGIC_NUMBER", 6) class Skip: @@ -31,7 +32,7 @@ class ModelScanSkipped: def __init__( self, scan_name: str, - category: SkipCategories, + category: Property, message: str, source: str, ) -> None: diff --git a/poetry.lock b/poetry.lock index ce145bd..8ae0e29 100644 --- a/poetry.lock +++ b/poetry.lock @@ -781,13 +781,13 @@ license = ["ukkonen"] [[package]] name = "idna" -version = "3.6" +version = "3.7" description = "Internationalized Domain Names in Applications (IDNA)" optional = false python-versions = ">=3.5" files = [ - {file = "idna-3.6-py3-none-any.whl", hash = "sha256:c05567e9c24a6b9faaa835c4821bad0590fbb9d5779e7caa6e1cc4978e7eb24f"}, - {file = "idna-3.6.tar.gz", hash = "sha256:9ecdbbd083b06798ae1e86adcbfe8ab1479cf864e4ee30fe4e46a003d12491ca"}, + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, ] [[package]] @@ -822,13 +822,13 @@ files = [ [[package]] name = "jinja2" -version = "3.1.3" +version = "3.1.4" description = "A very fast and expressive template engine." optional = false python-versions = ">=3.7" files = [ - {file = "Jinja2-3.1.3-py3-none-any.whl", hash = "sha256:7d6d50dd97d52cbc355597bd845fabfbac3f551e1f99619e39a35ce8c370b5fa"}, - {file = "Jinja2-3.1.3.tar.gz", hash = "sha256:ac8bd6544d4bb2c9792bf3a159e80bba8fda7f07e81bc3aed565432d5925ba90"}, + {file = "jinja2-3.1.4-py3-none-any.whl", hash = "sha256:bc5dd2abb727a5319567b7a813e6a2e7318c39f4f487cfe6c89c6f9c7d25197d"}, + {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"}, ] [package.dependencies] @@ -1618,13 +1618,13 @@ files = [ [[package]] name = "requests" -version = "2.31.0" +version = "2.32.2" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.2-py3-none-any.whl", hash = "sha256:fc06670dd0ed212426dfeb94fc1b983d917c4f9847c863f313c9dfaaffb7c23c"}, + {file = "requests-2.32.2.tar.gz", hash = "sha256:dd951ff5ecf3e3b3aa26b40703ba77495dab41da839ae72ef3c8e5d8e2433289"}, ] [package.dependencies] @@ -1657,13 +1657,13 @@ rsa = ["oauthlib[signedtoken] (>=3.0.0)"] [[package]] name = "rich" -version = "13.7.0" +version = "13.7.1" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" optional = false python-versions = ">=3.7.0" files = [ - {file = "rich-13.7.0-py3-none-any.whl", hash = "sha256:6da14c108c4866ee9520bbffa71f6fe3962e193b7da68720583850cd4548e235"}, - {file = "rich-13.7.0.tar.gz", hash = "sha256:5cb5123b5cf9ee70584244246816e9114227e0b98ad9176eede6ad54bf5403fa"}, + {file = "rich-13.7.1-py3-none-any.whl", hash = "sha256:4edbae314f59eb482f54e9e30bf00d33350aaa94f4bfcd4e9e3110e64d0d7222"}, + {file = "rich-13.7.1.tar.gz", hash = "sha256:9be308cb1fe2f1f57d67ce99e95af38a1e2bc71ad9813b0e247cf7ffbcc3a432"}, ] [package.dependencies] @@ -2027,13 +2027,13 @@ tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "types-requests" -version = "2.31.0.20240106" +version = "2.31.0.20240406" description = "Typing stubs for requests" optional = false python-versions = ">=3.8" files = [ - {file = "types-requests-2.31.0.20240106.tar.gz", hash = "sha256:0e1c731c17f33618ec58e022b614a1a2ecc25f7dc86800b36ef341380402c612"}, - {file = "types_requests-2.31.0.20240106-py3-none-any.whl", hash = "sha256:da997b3b6a72cc08d09f4dba9802fdbabc89104b35fe24ee588e674037689354"}, + {file = "types-requests-2.31.0.20240406.tar.gz", hash = "sha256:4428df33c5503945c74b3f42e82b181e86ec7b724620419a2966e2de604ce1a1"}, + {file = "types_requests-2.31.0.20240406-py3-none-any.whl", hash = "sha256:6216cdac377c6b9a040ac1c0404f7284bd13199c0e1bb235f4324627e8898cf5"}, ] [package.dependencies] @@ -2088,13 +2088,13 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [[package]] name = "werkzeug" -version = "3.0.1" +version = "3.0.3" description = "The comprehensive WSGI web application library." optional = true python-versions = ">=3.8" files = [ - {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, - {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, + {file = "werkzeug-3.0.3-py3-none-any.whl", hash = "sha256:fc9645dc43e03e4d630d23143a04a7f947a9a3b5727cd535fdfe155a17cc48c8"}, + {file = "werkzeug-3.0.3.tar.gz", hash = "sha256:097e5bfda9f0aba8da6b8545146def481d06aa7d3266e7448e2cccf67dd8bd18"}, ] [package.dependencies]