Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/pip/mypy-1.9.0
Browse files Browse the repository at this point in the history
  • Loading branch information
seanpmorgan authored May 30, 2024
2 parents 725e8f5 + b9117e1 commit 37aa4f4
Show file tree
Hide file tree
Showing 9 changed files with 79 additions and 46 deletions.
1 change: 0 additions & 1 deletion modelscan/error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import Enum
from modelscan.model import Model
import abc
from pathlib import Path
Expand Down
12 changes: 7 additions & 5 deletions modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from collections import defaultdict

from modelscan.settings import Property

logger = logging.getLogger("modelscan")


Expand All @@ -16,8 +18,8 @@ class IssueSeverity(Enum):
CRITICAL = 4


class IssueCode(Enum):
UNSAFE_OPERATOR = 1
class IssueCode:
UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1)


class IssueDetails(metaclass=abc.ABCMeta):
Expand All @@ -40,14 +42,14 @@ class Issue:

def __init__(
self,
code: IssueCode,
code: Property,
severity: IssueSeverity,
details: IssueDetails,
) -> None:
"""
Create a issue with given information
:param code: Code of the issue from the issue code enum.
:param code: Code of the issue from the issue code class.
:param severity: The severity level of the issue from Severity enum.
:param details: An implementation of the IssueDetails object.
"""
Expand Down Expand Up @@ -82,7 +84,7 @@ def __hash__(self) -> int:

def print(self) -> None:
issue_description = self.code.name
if self.code == IssueCode.UNSAFE_OPERATOR:
if self.code.value == IssueCode.UNSAFE_OPERATOR.value:
issue_description = "Unsafe operator"
else:
logger.error("No issue description for issue code %s", self.code)
Expand Down
5 changes: 4 additions & 1 deletion modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -27,7 +28,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "keras_h5" not in model.get_context("formats"):
if SupportedModelFormats.KERAS_H5.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
5 changes: 4 additions & 1 deletion modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import SupportedModelFormats


logger = logging.getLogger("modelscan")


class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(self, model: Model) -> Optional[ScanResults]:
if "keras" not in model.get_context("formats"):
if SupportedModelFormats.KERAS.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
13 changes: 10 additions & 3 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
scan_pytorch,
)
from modelscan.model import Model
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -18,7 +19,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "pytorch" not in model.get_context("formats"):
if SupportedModelFormats.PYTORCH.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

if _is_zipfile(model.get_source(), model.get_stream()):
Expand All @@ -45,7 +48,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "numpy" not in model.get_context("formats"):
if SupportedModelFormats.NUMPY.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

results = scan_numpy(
Expand All @@ -69,7 +74,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "pickle" not in model.get_context("formats"):
if SupportedModelFormats.PICKLE.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

results = scan_pickle_bytes(
Expand Down
5 changes: 4 additions & 1 deletion modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
from modelscan.scanners.scan import ScanBase, ScanResults
from modelscan.model import Model
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -31,7 +32,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if "tf_saved_model" not in model.get_context("formats"):
if SupportedModelFormats.TENSORFLOW.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
29 changes: 22 additions & 7 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@

from modelscan._version import __version__


class Property:
def __init__(self, name: str, value: Any) -> None:
self.name = name
self.value = value


class SupportedModelFormats:
TENSORFLOW = Property("TENSORFLOW", "tensorflow")
KERAS_H5 = Property("KERAS_H5", "keras_h5")
KERAS = Property("KERAS", "keras")
NUMPY = Property("NUMPY", "numpy")
PYTORCH = Property("PYTORCH", "pytorch")
PICKLE = Property("PICKLE", "pickle")


DEFAULT_REPORTING_MODULES = {
"console": "modelscan.reports.ConsoleReport",
"json": "modelscan.reports.JSONReport",
Expand Down Expand Up @@ -59,13 +75,12 @@
"middlewares": {
"modelscan.middlewares.FormatViaExtensionMiddleware": {
"formats": {
"tf": [".pb"],
"tf_saved_model": [".pb"],
"keras_h5": [".h5"],
"keras": [".keras"],
"numpy": [".npy"],
"pytorch": [".bin", ".pt", ".pth", ".ckpt"],
"pickle": [
SupportedModelFormats.TENSORFLOW: [".pb"],
SupportedModelFormats.KERAS_H5: [".h5"],
SupportedModelFormats.KERAS: [".keras"],
SupportedModelFormats.NUMPY: [".npy"],
SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
SupportedModelFormats.PICKLE: [
".pkl",
".pickle",
".joblib",
Expand Down
17 changes: 9 additions & 8 deletions modelscan/skip.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import logging
from enum import Enum

from modelscan.settings import Property

logger = logging.getLogger("modelscan")


class SkipCategories(Enum):
SCAN_NOT_SUPPORTED = 1
BAD_ZIP = 2
MODEL_CONFIG = 3
H5_DATA = 4
NOT_IMPLEMENTED = 5
MAGIC_NUMBER = 6
class SkipCategories:
SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1)
BAD_ZIP = Property("BAD_ZIP", 2)
MODEL_CONFIG = Property("MODEL_CONFIG", 3)
H5_DATA = Property("H5_DATA", 4)
NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5)
MAGIC_NUMBER = Property("MAGIC_NUMBER", 6)


class Skip:
Expand All @@ -31,7 +32,7 @@ class ModelScanSkipped:
def __init__(
self,
scan_name: str,
category: SkipCategories,
category: Property,
message: str,
source: str,
) -> None:
Expand Down
38 changes: 19 additions & 19 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 37aa4f4

Please sign in to comment.