Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace enums to be extendable #151

Merged
merged 3 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion modelscan/error.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from enum import Enum
from modelscan.model import Model
import abc
from pathlib import Path
Expand Down
12 changes: 7 additions & 5 deletions modelscan/issues.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from collections import defaultdict

from modelscan.settings import Property

logger = logging.getLogger("modelscan")


Expand All @@ -16,8 +18,8 @@ class IssueSeverity(Enum):
CRITICAL = 4


class IssueCode(Enum):
UNSAFE_OPERATOR = 1
class IssueCode:
UNSAFE_OPERATOR = Property("UNSAFE_OPERATOR", 1)


class IssueDetails(metaclass=abc.ABCMeta):
Expand All @@ -40,14 +42,14 @@ class Issue:

def __init__(
self,
code: IssueCode,
code: Property,
severity: IssueSeverity,
details: IssueDetails,
) -> None:
"""
Create a issue with given information

:param code: Code of the issue from the issue code enum.
:param code: Code of the issue from the issue code class.
:param severity: The severity level of the issue from Severity enum.
:param details: An implementation of the IssueDetails object.
"""
Expand Down Expand Up @@ -82,7 +84,7 @@ def __hash__(self) -> int:

def print(self) -> None:
issue_description = self.code.name
if self.code == IssueCode.UNSAFE_OPERATOR:
if self.code.value == IssueCode.UNSAFE_OPERATOR.value:
issue_description = "Unsafe operator"
else:
logger.error("No issue description for issue code %s", self.code)
Expand Down
6 changes: 4 additions & 2 deletions modelscan/scanners/h5/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -28,7 +28,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.KERAS_H5 not in model.get_context("formats"):
if SupportedModelFormats.KERAS_H5.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
6 changes: 4 additions & 2 deletions modelscan/scanners/keras/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@
from modelscan.scanners.scan import ScanResults
from modelscan.scanners.saved_model.scan import SavedModelLambdaDetectScan
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats


logger = logging.getLogger("modelscan")


class KerasLambdaDetectScan(SavedModelLambdaDetectScan):
def scan(self, model: Model) -> Optional[ScanResults]:
if DefaultModelFormats.KERAS not in model.get_context("formats"):
if SupportedModelFormats.KERAS.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
14 changes: 10 additions & 4 deletions modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
scan_pytorch,
)
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -19,7 +19,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.PYTORCH not in model.get_context("formats"):
if SupportedModelFormats.PYTORCH.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

if _is_zipfile(model.get_source(), model.get_stream()):
Expand All @@ -46,7 +48,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.NUMPY not in model.get_context("formats"):
if SupportedModelFormats.NUMPY.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

results = scan_numpy(
Expand All @@ -70,7 +74,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.PICKLE not in model.get_context("formats"):
if SupportedModelFormats.PICKLE.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

results = scan_pickle_bytes(
Expand Down
6 changes: 4 additions & 2 deletions modelscan/scanners/saved_model/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from modelscan.issues import Issue, IssueCode, IssueSeverity, OperatorIssueDetails
from modelscan.scanners.scan import ScanBase, ScanResults
from modelscan.model import Model
from modelscan.settings import DefaultModelFormats
from modelscan.settings import SupportedModelFormats

logger = logging.getLogger("modelscan")

Expand All @@ -32,7 +32,9 @@ def scan(
self,
model: Model,
) -> Optional[ScanResults]:
if DefaultModelFormats.TENSORFLOW not in model.get_context("formats"):
if SupportedModelFormats.TENSORFLOW.value not in [
format_property.value for format_property in model.get_context("formats")
]:
return None

dep_error = self.handle_binary_dependencies()
Expand Down
33 changes: 19 additions & 14 deletions modelscan/settings.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
import tomlkit

from enum import Enum
from typing import Any

from modelscan._version import __version__


class DefaultModelFormats(Enum):
TENSORFLOW = "tensorflow"
KERAS_H5 = "keras_h5"
KERAS = "keras"
NUMPY = "numpy"
PYTORCH = "pytorch"
PICKLE = "pickle"
class Property:
    """A named constant carrying an arbitrary value.

    Used as an extendable stand-in for an ``Enum`` member: instances are
    grouped as plain class attributes, so other packages can define
    additional ones without subclassing an enum.

    Two properties are equal when both their ``name`` and ``value`` are
    equal, so a property created by an extension compares equal to a
    built-in one describing the same constant.

    :param name: Identifier of the constant (e.g. ``"KERAS"``).
    :param value: Payload of the constant (e.g. ``"keras"`` or ``1``).
    """

    def __init__(self, name: str, value: Any) -> None:
        self.name = name
        self.value = value

    def __repr__(self) -> str:
        # Keeps log lines and debugger output readable instead of showing
        # the default ``<... object at 0x...>`` form.
        return f"{type(self).__name__}(name={self.name!r}, value={self.value!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Property):
            return NotImplemented
        return self.name == other.name and self.value == other.value

    def __hash__(self) -> int:
        # Defining __eq__ removes the inherited __hash__, yet instances must
        # stay usable as dict keys. Hash on the name only: it is consistent
        # with __eq__ and still works when ``value`` itself is unhashable.
        return hash(self.name)


class SupportedModelFormats:
TENSORFLOW = Property("TENSORFLOW", "tensorflow")
KERAS_H5 = Property("KERAS_H5", "keras_h5")
KERAS = Property("KERAS", "keras")
NUMPY = Property("NUMPY", "numpy")
PYTORCH = Property("PYTORCH", "pytorch")
PICKLE = Property("PICKLE", "pickle")
Comment on lines +8 to +20
Copy link
Contributor

@CandiedCode CandiedCode May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Enum already supports name and value.

I'd recommend adding support for `str` if you want to enable additional behavior such as f-string support.

Screenshot 2024-05-23 at 4 53 26 PM

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replacing the enum because it's not extendable: in its current state, other applications that add their own scanners can't add additional model formats or issue codes.



DEFAULT_REPORTING_MODULES = {
Expand Down Expand Up @@ -70,12 +75,12 @@ class DefaultModelFormats(Enum):
"middlewares": {
"modelscan.middlewares.FormatViaExtensionMiddleware": {
"formats": {
DefaultModelFormats.TENSORFLOW: [".pb"],
DefaultModelFormats.KERAS_H5: [".h5"],
DefaultModelFormats.KERAS: [".keras"],
DefaultModelFormats.NUMPY: [".npy"],
DefaultModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
DefaultModelFormats.PICKLE: [
SupportedModelFormats.TENSORFLOW: [".pb"],
SupportedModelFormats.KERAS_H5: [".h5"],
SupportedModelFormats.KERAS: [".keras"],
SupportedModelFormats.NUMPY: [".npy"],
SupportedModelFormats.PYTORCH: [".bin", ".pt", ".pth", ".ckpt"],
SupportedModelFormats.PICKLE: [
".pkl",
".pickle",
".joblib",
Expand Down
17 changes: 9 additions & 8 deletions modelscan/skip.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import logging
from enum import Enum

from modelscan.settings import Property

logger = logging.getLogger("modelscan")


class SkipCategories(Enum):
SCAN_NOT_SUPPORTED = 1
BAD_ZIP = 2
MODEL_CONFIG = 3
H5_DATA = 4
NOT_IMPLEMENTED = 5
MAGIC_NUMBER = 6
class SkipCategories:
SCAN_NOT_SUPPORTED = Property("SCAN_NOT_SUPPORTED", 1)
BAD_ZIP = Property("BAD_ZIP", 2)
MODEL_CONFIG = Property("MODEL_CONFIG", 3)
H5_DATA = Property("H5_DATA", 4)
NOT_IMPLEMENTED = Property("NOT_IMPLEMENTED", 5)
MAGIC_NUMBER = Property("MAGIC_NUMBER", 6)
Comment on lines +9 to +15
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would IntEnum be better here?

Screenshot 2024-05-23 at 5 03 02 PM

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same problem: IntEnum is not extendable either.



class Skip:
Expand All @@ -31,7 +32,7 @@ class ModelScanSkipped:
def __init__(
self,
scan_name: str,
category: SkipCategories,
category: Property,
message: str,
source: str,
) -> None:
Expand Down